#     Copyright 2007-8 Jim Bublitz <jbublitz@nwinternet.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the
# Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA

# Built-in C++ scalar type spellings.  A type whose pointer/reference-stripped
# name appears here is treated as a scalar and is not scoped.
scalarTypes = (
    "int", "long int", "short int", "signed int", "unsigned int",
    "unsigned long int", "unsigned short", "unsigned short int",
    "char", "unsigned char", "signed char",
    "float", "double",
    "long", "long long", "unsigned long", "signed long",
    "unsigned long long", "signed long long",
    "short", "unsigned", "signed",
    "bool", "void", "wchar_t",
)
    
def computeScopes (stateInfo, symbolData, modData):
    """Run a ScopeComputer over every parsed object of every file.

    For each file in symbolData.files, the file's object list is looked up
    in symbolData.fileData, the computer's context is reset to the global
    scope (''), and the objects are then handed to compute() in order.
    """
    computer = ScopeComputer (stateInfo, symbolData, modData)
    for fileName in symbolData.files:
        entries = symbolData.fileData [fileName]
        # each file starts at global scope
        computer.context = ''
        for entry in entries:
            computer.compute (entry)

class ScopeComputer:
    """Resolve the C++ scope of names referenced by parsed objects.

    compute() is called on each parsed object of a file in order.  Class and
    namespace markers push/pop stateInfo so that self.context holds the
    enclosing scope while typedefs, functions and variables are processed.
    Resolved scopes are written back onto the objects (bases, argumentType,
    scope, defaultValue), and the attributes.mtNeeded / ioNeeded / mcNeeded
    flags are raised or propagated for types that could not be resolved or
    that need special handling.
    """

    def __init__ (self, stateInfo, symbolData, modData):
        self.stateInfo  = stateInfo
        self.symbolData = symbolData
        self.modData    = modData
        self.index      = self.symbolData.index
        # template parameter names of the class being processed; arguments
        # whose type is one of these are not scoped
        self.templateParams = []
        # scope of the current class/namespace ('' = global).  Initialized
        # here so the lookup methods work before compute() has run.
        self.context    = ''

    def primaryType (self, argtype):
        """Return *argtype* truncated at its first '*' or '&'.

        Only the pointer/reference suffix is removed; anything before it
        (e.g. a leading 'const' or a trailing space) is kept, so the
        identifier is "almost" unmodified.
        """
        length = len (argtype)
        star = argtype.find ("*")
        amp  = argtype.find ("&")
        if star < 0:
            star = length
        if amp < 0:
            amp = length
        return argtype [:min (star, amp)]

    def compute (self, obj):
        """Process one parsed object and update self.context afterwards.

        'class'/'namespace' push a scope, and 'endclass'/'endnamespace' pop
        it.  Typedefs, functions and variables have their referenced types
        scoped.
        """
        if obj.objectType == 'class':
            if obj.bases:
                # bases are resolved before the class scope is pushed
                self.computeBaseScopes (obj)
            self.stateInfo.pushClass (obj.name, obj)
            if obj.templateParams:
                self.templateParams = obj.templateParams.split (',')
        elif obj.objectType == 'namespace':
            self.stateInfo.pushNamespace (obj.name, obj)
        elif obj.objectType == 'endclass':
            self.stateInfo.popClass (True)
            self.templateParams = []
        elif obj.objectType == 'endnamespace':
            self.stateInfo.popNamespace (True)
        elif obj.objectType == 'typedef':
            self.computeTypedefScopes (obj)
        elif obj.objectType == 'function':
            self.computeFunctionScopes (obj)
        elif obj.objectType == 'variable':
            self.computeArgumentScope (obj.variable)

        self.context = self.stateInfo.scope

    def computeBaseScopes (self, obj):
        """Replace each entry of obj.bases with its fully scoped name.

        Template bases (containing '<') and bases that cannot be resolved
        are kept as written and set obj.undefBase = True.  Globally scoped
        bases are kept as written.
        """
        newBases = []
        for base in obj.bases:
            if '::' in base:
                parts    = base.split ('::')
                base0    = parts [-1]
                objScope = '::'.join (parts [:-1])
            else:
                base0    = base
                objScope = ''

            if '<' in base:
                # templated bases are not resolved here
                obj.undefBase = True
                newBases.append (base)
                continue

            scope = self.computeObjectScope (base0, objScope, ['class', 'typedef'])
            if scope is None:
                obj.undefBase = True
                newBases.append (base)
            elif scope == '':
                newBases.append (base)
            else:
                newBases.append ('::'.join ([scope, base0]))

        obj.bases = newBases

    def sortByLength (self, a, b):
        """cmp-style comparator that puts longer sequences first.

        Kept for backward compatibility.  computeObjectScope now sorts with
        key=len, reverse=True, which gives the same stable order and also
        works on Python 3 (list.sort no longer accepts a cmp function).
        """
        la = len (a)
        lb = len (b)
        if la == lb:
            return 0
        if la > lb:
            return -1
        return 1

    def matchBaseClass (self, base, matches):
        """Search class *base* and its ancestors for one of *matches*.

        A match is a candidate whose scope equals *base* or ends with
        '::<base>'.  Ancestors come from symbolData.parents and are searched
        depth-first.  Returns the matching scope string, or None.
        """
        for match in matches:
            if base == match.scope:
                return base
            elif match.scope.endswith ('::%s' % base):
                return match.scope

        if base in self.symbolData.parents:
            for nextBase in self.symbolData.parents [base]:
                found = self.matchBaseClass (nextBase, matches)
                if found:
                    return found
        return None

    def computeObjectScope (self, objType, objScope, types = None):
        """Find the scope in which the name *objType* is defined.

        objType  -- unqualified name to look up in the symbol index
        objScope -- qualifier to match against: the explicit (possibly
                    partial) qualifier written in the source, or whatever
                    scope the caller passes for unqualified names
        types    -- object kinds passed to index.getMatchingObjects
                    (defaults to an empty list)

        Returns the scope string ('' means global), or None if no visible
        definition was found.  Outside any class/namespace, *objScope* is
        returned unchanged, because the name must already be written with
        its correct scope in the header.
        """
        if types is None:
            types = []

        if not self.context:
            return objScope

        # objects with this name (restricted to compatible scopes when a
        # qualifier was given)
        if objScope:
            matches = [candidate for candidate in self.index.getMatchingObjects (objType, types)
                       if candidate.scope.endswith (objScope) or candidate.scope == '']
        else:
            matches = self.index.getMatchingObjects (objType, types)

        # no matches: the name has not been defined
        if not matches:
            return None

        # first check the class hierarchy, starting with the current class
        className = self.stateInfo.currentClass ()
        if className:
            match = self.matchBaseClass (className, matches)
            if match:
                return match

        candidates = []

        # next, check enclosing scopes (including global)
        for match in matches:
            # partially scoped name (we still look for the narrowest scope)
            if objScope and objScope == match.scope:
                candidates.append (match.scope)
                break
            elif match.scope.endswith ('::%s' % objScope):
                candidates.append (match.scope)

            # global ('') matches any context; an exact context match ends
            # the search
            if match.scope == self.context:
                candidates.append (match.scope)
                break
            elif self.context.startswith ('%s::' % match.scope) or match.scope == '':
                candidates.append (match.scope)

        # prefer the longest (most restrictive) enclosing scope.  The sort
        # is stable, so equal-length candidates keep their discovery order.
        if len (candidates) > 1:
            candidates.sort (key = len, reverse = True)
        if candidates:
            return candidates [0]
        return None

    def computeTypedefScopes (self, obj):
        """Scope the aliased type of typedef *obj* (obj.argumentType).

        Template types are delegated to computeTemplateScope0.  An
        unresolvable type sets obj.attributes.mtNeeded.
        """
        if '<' in obj.argumentType:
            self.computeTemplateScope0 (obj)
            return

        basetype0 = self.primaryType (obj.argumentType)
        if '::' in basetype0:
            parts    = basetype0.split ('::')
            basetype = parts [-1]
            objScope = '::'.join (parts [:-1])
        else:
            basetype = basetype0
            objScope = obj.scope

        scope = self.computeObjectScope (basetype, objScope, ['mappedtype', 'exception', 'class', 'typedef', 'enum'])
        if scope is None:
            obj.attributes.mtNeeded = True
        elif scope:
            obj.argumentType = obj.argumentType.replace (basetype0, '::'.join ([scope, basetype]))

    def computeFunctionScopes (self, obj):
        """Scope every argument and return value of function *obj*.

        Arguments typed by a template parameter of the enclosing class are
        skipped.  Each argument's mcNeeded/mtNeeded/ioNeeded flags are
        OR-ed into the function's attributes.
        """
        for argument in obj.arguments + obj.returns:
            if self.templateParams and self.primaryType (argument.argumentType) in self.templateParams:
                continue
            self.computeArgumentScope (argument)

            obj.attributes.mcNeeded = obj.attributes.mcNeeded or argument.attributes.mcNeeded
            obj.attributes.mtNeeded = obj.attributes.mtNeeded or argument.attributes.mtNeeded
            obj.attributes.ioNeeded = obj.attributes.ioNeeded or argument.attributes.ioNeeded

    def computeArgumentScope (self, arg):
        """Resolve the scope of argument/variable *arg* in place.

        Sets arg.scope (or rewrites template types), sets mtNeeded for
        unresolvable types, and scopes the default value if there is one.
        Built-in scalars, 'ctor'/'dtor' and the SIP slot markers are left
        untouched.
        """
        argumentType = arg.argumentType
        if argumentType in ['SIP_SLOT_CON ()', 'SIP_RXOBJ_CON']:
            return

        basetype = self.primaryType (arg.argumentType)
        if basetype in scalarTypes or basetype in ['ctor', 'dtor']:
            return

        if basetype in self.symbolData.definedScalar:
            # the type differs from its primary type only when it carries a
            # '*' or '&' suffix; flag that case
            if basetype != arg.argumentType:
                arg.attributes.ioNeeded = True
            return

        # templates are handled differently
        if '<' in basetype:
            self.computeTemplateScope0 (arg)
        else:
            scope = self.computeObjectScope (basetype, arg.scope, ['mappedtype', 'exception', 'class', 'typedef', 'enum'])
            if scope is None:
                arg.attributes.mtNeeded = True
            else:
                arg.scope = scope

        if arg.defaultValue:
            self.computeDefaultScopes (arg)

    def collectParams (self, template, p, t):
        """Collect the parameter types of *template* into *p* and *t*.

        A Template object describes one or more parameters.  If a parameter
        is itself a template, the attached template.template object is
        collected recursively.

        p -- dict whose keys are the unique non-template parameter types;
             duplicates (e.g. QMap<QString,QString>) are stored only once,
             so they are scoped only once
        t -- dict mapping each template name to its list of primary types
        """
        t [template.name] = []
        params = template.params.split (',')
        for param in params:
            basetype = self.primaryType (param)
            t [template.name].append (basetype)
            if '<' in param and not 'static_cast' in param:
                self.collectParams (template.template, p, t)
            elif '>' in param:
                continue
            else:
                p [basetype] = ''

    def computeTemplateScope0 (self, arg):
        """Scope every non-scalar parameter type of template-typed *arg*.

        Resolved parameters are rewritten in arg.argumentType (and in
        arg.defaultValue when present).  The result is then checked against
        the known templates/mapped types by verifyTemplate.
        """
        params    = {}
        templates = {}
        # gather the unique parameter types
        self.collectParams (arg.template, params, templates)

        for param in params:
            if param in scalarTypes\
                or param in self.symbolData.definedScalar\
                or '<' in param:
                continue

            if '::' in param:
                parts    = param.split ('::')
                basetype = parts [-1]
                objScope = '::'.join (parts [:-1])
            else:
                basetype = param
                objScope = arg.scope

            scope = self.computeObjectScope (basetype, objScope, ['mappedtype', 'class', 'typedef', 'enum'])
            if scope is None:
                matches = self.index.getMatchingObjects (arg.template.name, ['template'])
                if matches != []:
                    # NOTE(review): the result of this loop is never used.
                    # mtNeeded is set only when there are no matching
                    # templates at all.  A for/else may have been intended;
                    # behavior is left unchanged pending confirmation.
                    for obj in matches:
                        if obj.template and param in obj.template.params:
                            break
                else:
                    arg.attributes.mtNeeded = True

            elif scope:
                arg.argumentType = arg.argumentType.replace (param, '::'.join ([scope, basetype]))
                if hasattr (arg, 'defaultValue') and arg.defaultValue:
                    arg.defaultValue = arg.defaultValue.replace (param, '::'.join ([scope, basetype]))

        # also check the template name and parameter types against the
        # available template and mapped types
        self.verifyTemplate (arg)

    def deconstructTemplate (self, template):
        """Return a signature list for the parameters of *template*.

        Only the first whitespace-separated token of each parameter is
        used.  Scalar types are kept verbatim (including any '*'/'&'
        suffix).  Other types become 'nonscalar' plus their '*'/'&' suffix.
        """
        typelist = []
        params = template.params.split (',')
        for param in params:
            if ' ' in param:
                param = param.split () [0].strip ()
            base = self.primaryType (param)
            if base in scalarTypes:
                typelist.append (param)
            else:
                typelist.append ('nonscalar' + param [len (base):])
        return typelist

    def verifyTemplate (self, arg):
        """Set arg.attributes.mtNeeded unless a known 'template' matches.

        The template name is the part of arg.argumentType before '<'.  A
        match either has a template whose typelist equals the parameter
        signature from deconstructTemplate, or has no template and a
        typelist equal to arg's primary type.
        """
        pos = arg.argumentType.find ('<')
        basetype = arg.argumentType [:pos]
        matches = self.index.getMatchingObjects (basetype, ['template'])
        if not matches:
            arg.attributes.mtNeeded = True
            return

        argTypelist = self.deconstructTemplate (arg.template)

        for match in matches:
            if (match.template and argTypelist == match.typelist)\
                or (not match.template and self.primaryType (arg.argumentType) == match.typelist):
                return

        arg.attributes.mtNeeded = True

    def computeDefaultScopes (self, arg):
        """Scope the names the parser found in arg's default value.

        For each non-template entry of arg.defaultTypes, a name ending in
        '()' is looked up as a function/typedef, and other names as types.
        Enumerators that match the argument's enum type are qualified with
        arg.scope.  Other resolved names are qualified with their resolved
        scope.  Unresolvable names set arg.attributes.mtNeeded.
        """
        for item in arg.defaultTypes:
            if '<' in item:
                # TODO: template-typed defaults are not scoped yet
                continue

            if item.endswith ('()'):
                basetype0 = item [:-2]
                types     = ['function', 'typedef']
            else:
                basetype0 = self.primaryType (item)
                types     = ['mappedtype', 'exception', 'class', 'typedef', 'enum']

            objScope = arg.scope

            if '::' in basetype0:
                parts    = basetype0.split ('::')
                basetype = parts [-1]
                objScope = '::'.join (parts [:-1])
            else:
                basetype = basetype0

            if not 'function' in types:
                # the argument type might be an enum, and the default value
                # one of its enumerators
                enumMatches = [obj for obj in self.index.getMatchingObjects (self.primaryType (arg.argumentType), ['enum', 'typedef'])
                               if obj.scope == arg.scope]
                if enumMatches:
                    e = enumMatches [0]
                    if e.objectType == 'typedef' and 'QFlags' in e.argumentType:
                        # assumes the form 'QFlags<Enum>': drop the 7-char
                        # 'QFlags<' prefix and the trailing '>'
                        enumName = e.argumentType [7:-1]
                    else:
                        enumName = e.name
                else:
                    enumName = ''

                matches = [obj for obj in self.index.getMatchingObjects (basetype0, ['enumerator'])
                           if obj.enum == enumName or obj.scope == arg.scope]
                if matches and arg.scope:
                    arg.defaultValue = arg.defaultValue.replace (basetype0, '::'.join ([arg.scope, basetype]))
                    continue

            scope = self.computeObjectScope (basetype, objScope, types)
            if scope is None:
                arg.attributes.mtNeeded = True
            elif scope:
                # qualify with the *resolved* scope (as computeTypedefScopes
                # and computeTemplateScope0 do), not the lookup qualifier,
                # which may be '' and would yield '::name'
                arg.defaultValue = arg.defaultValue.replace (basetype0, '::'.join ([scope, basetype]))
