Source Loader

SourceLoader

Bases: object

Source code in src/core/sourceLoader.py
class SourceLoader(object):

    def __init__(self):
        self._patch_list = []
        self._npatch = 0
        self._source_list = []
        self._nsource = 0
        self._match_dict = {}
        self._nmatch = 0
        self._bit_vector = bitarray.bitarray(common.bloomfilter_size)
        self._results = {}
        self._source_hashes = []

    def traverse(self, source_path, patch):
        '''
        Traverse source files
        '''
        common.verbose_print('[+] traversing source files')
        start_time = time.time()
        self._patch_list = patch.items()
        self._npatch = patch.length()

        if os.path.isfile(source_path):
            magic_type = common.file_type(source_path)
            common.verbose_print('  [-] %s: %s' % (source_path, magic_type))
            if magic_type.startswith('text'):
                main_type, sub_type = magic_type.split('/')
                magic_ext = self._get_file_type(sub_type)
                self._process(source_path, magic_ext)
        elif os.path.isdir(source_path):
            for root, dirs, files in os.walk(source_path):
                for file in files:
                    file_path = os.path.join(root, file)
                    magic_type = common.file_type(file_path)
                    common.verbose_print('  [-] %s: %s' % (file_path, magic_type))
                    if magic_type.startswith('text'):
                        main_type, sub_type = magic_type.split('/')
                        magic_ext = self._get_file_type(sub_type)
                        self._process(file_path, magic_ext)

        elapsed_time = time.time() - start_time
        common.verbose_print('[+] %d possible matches ... %.1fs\n' % (self._nmatch, elapsed_time))
        return self._nmatch

    def _process(self, source_path, magic_ext):
        '''
        Normalize a source file and build a Bloom filter for queries
        '''
        with open(source_path, 'r') as source_file:
            source_orig_lines = source_file.read()

        source_norm_lines = self._normalize(source_orig_lines, magic_ext)
        self._query_bloomfilter(source_norm_lines, magic_ext)
        # NOTE: populating self._source_list / self._nsource is currently
        # disabled, so items() and length() report no sources.

    def _normalize(self, source, ext):
        '''
        Normalize a source file
        '''
        # Language-specific comment stripping
        if ext == common.FileExt.C or ext == common.FileExt.Java:
            norm_lines = []
            for c in common.c_regex.finditer(source):
                if c.group('noncomment'):
                    norm_lines.append(c.group('noncomment'))
                elif c.group('multilinecomment'):
                    # Keep one newline per comment line so line counts stay aligned
                    norm_lines.append('\n' * c.group('multilinecomment').count('\n'))
            source = ''.join(norm_lines)
        elif ext == common.FileExt.Python:
            source = re.sub(re.compile("'''.*?'''", re.DOTALL), "", source)  # Remove triple-single-quoted blocks
            source = re.sub(re.compile('""".*?"""', re.DOTALL), "", source)  # Remove triple-double-quoted blocks
            source = re.sub(r"#.*?\n", "", source)  # Remove single-line comments
        elif ext == common.FileExt.ShellScript:
            source = ''.join([c.group('noncomment') for c in common.shellscript_regex.finditer(source) if c.group('noncomment')])
        elif ext == common.FileExt.Perl:
            source = ''.join([c.group('noncomment') for c in common.perl_regex.finditer(source) if c.group('noncomment')])
        elif ext == common.FileExt.PHP:
            norm_lines = []
            for c in common.php_regex.finditer(source):
                if c.group('noncomment'):
                    norm_lines.append(c.group('noncomment'))
                elif c.group('multilinecomment'):
                    norm_lines.append('\n' * c.group('multilinecomment').count('\n'))
            source = ''.join(norm_lines)
        elif ext == common.FileExt.Ruby:
            norm_lines = []
            for c in common.ruby_regex.finditer(source):
                if c.group('noncomment'):
                    norm_lines.append(c.group('noncomment'))
                elif c.group('multilinecomment'):
                    norm_lines.append('\n' * c.group('multilinecomment').count('\n'))
            source = ''.join(norm_lines)

        # Remove whitespace except newlines
        source = common.whitespaces_regex.sub("", source)
        # Convert to lowercase
        return source.lower()

    def _query_bloomfilter(self, source_norm_lines, magic_ext):
        source_norm_lines = source_norm_lines.split()
        if len(source_norm_lines) < common.ngram_size:
            return False

        self._bit_vector.setall(0)
        num_ngram = len(source_norm_lines) - common.ngram_size + 1
        is_vuln_source = False
        num_ngram_processed = 0
        for i in range(num_ngram):
            # When the filter gets too dense, query it against every patch,
            # then reset it and keep filling a fresh filter. Source
            # bookkeeping is disabled in _process, so only the match count
            # is updated here.
            if num_ngram_processed > common.bloomfilter_size / common.min_mn_ratio:
                for patch_id in range(self._npatch):
                    hash_list = self._patch_list[patch_id].hash_list
                    is_match = True
                    for h in hash_list:
                        if not self._bit_vector[h]:
                            is_match = False
                            break
                    if is_match:
                        is_vuln_source = True
                        self._nmatch += 1
                num_ngram_processed = 0
                self._bit_vector.setall(0)

            ngram = ''.join(source_norm_lines[i:i + common.ngram_size])
            # Three independent hashes per n-gram; masking with
            # (bloomfilter_size - 1) assumes the size is a power of two.
            hash1 = common.fnv1a_hash(ngram) & (common.bloomfilter_size - 1)
            hash2 = common.djb2_hash(ngram) & (common.bloomfilter_size - 1)
            hash3 = common.sdbm_hash(ngram) & (common.bloomfilter_size - 1)
            self._bit_vector[hash1] = 1
            self._bit_vector[hash2] = 1
            self._bit_vector[hash3] = 1
            num_ngram_processed += 1
            self._source_hashes.append([ngram, [hash1, hash2, hash3]])

        for patch_id in range(self._npatch):
            hash_list = self._patch_list[patch_id].hash_list
            is_match = True
            i = 0
            seq = 0
            for h in hash_list:
                # Patch hashes come in triples (three hashes per n-gram);
                # move to the next n-gram sequence after every third hash.
                if i == 3:
                    i = 0
                    seq += 1

                if patch_id not in self._match_dict:
                    self._match_dict[patch_id] = {}
                if seq not in self._match_dict[patch_id]:
                    self._match_dict[patch_id][seq] = {}

                matched = bool(self._bit_vector[h])
                if not matched:
                    is_match = False
                self._results[h] = {'Match': matched}
                self._match_dict[patch_id][seq][h] = matched
                i += 1

            if is_match:
                is_vuln_source = True
                self._nmatch += 1

        return is_vuln_source

    def _get_file_type(self, sub_type):
        '''
        Determine the file type from the MIME sub-type reported by the magic module
        '''
        magic_ext = None
        if sub_type.startswith('x-c'):
            magic_ext = common.FileExt.C
        elif sub_type == 'x-java':
            magic_ext = common.FileExt.Java
        elif sub_type == 'x-shellscript':
            magic_ext = common.FileExt.ShellScript
        elif sub_type == 'x-perl':
            magic_ext = common.FileExt.Perl
        elif sub_type == 'x-python':
            magic_ext = common.FileExt.Python
        elif sub_type == 'x-php':
            magic_ext = common.FileExt.PHP
        elif sub_type == 'x-ruby':
            magic_ext = common.FileExt.Ruby
        else:
            magic_ext = common.FileExt.Text
        return magic_ext

    def items(self):
        return self._source_list

    def length(self):
        return self._nsource

    def match_items(self):
        return self._match_dict

    def results(self):
        return self._results

    def source_hashes(self):
        return self._source_hashes

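The core trick in _query_bloomfilter is a split Bloom filter over source n-grams: every n-gram of the normalized source sets three bits (one per hash function), and a patch is reported as a possible match only when every bit of its precomputed hash list is set. Below is a minimal, self-contained sketch of that scheme. The hash functions are ordinary 32-bit FNV-1a, djb2, and sdbm implementations standing in for common.fnv1a_hash, common.djb2_hash, and common.sdbm_hash, and the two constants are stand-ins for common.bloomfilter_size and common.ngram_size; none of this is the module's actual code.

# Sketch of the n-gram Bloom filter used by _query_bloomfilter.
# Hash helpers and constants are stand-ins, not the real common module.
from bitarray import bitarray

BLOOMFILTER_SIZE = 2 ** 16   # assumed power of two, so (size - 1) works as a mask
NGRAM_SIZE = 4               # hypothetical; the tool reads common.ngram_size

def fnv1a(s):
    h = 0x811c9dc5
    for ch in s.encode():
        h = ((h ^ ch) * 0x01000193) & 0xffffffff
    return h

def djb2(s):
    h = 5381
    for ch in s.encode():
        h = ((h * 33) + ch) & 0xffffffff
    return h

def sdbm(s):
    h = 0
    for ch in s.encode():
        h = (ch + (h << 6) + (h << 16) - h) & 0xffffffff
    return h

def build_filter(norm_lines):
    """Set three bits per n-gram of normalized source lines."""
    bits = bitarray(BLOOMFILTER_SIZE)
    bits.setall(0)
    mask = BLOOMFILTER_SIZE - 1
    for i in range(len(norm_lines) - NGRAM_SIZE + 1):
        ngram = ''.join(norm_lines[i:i + NGRAM_SIZE])
        for h in (fnv1a(ngram), djb2(ngram), sdbm(ngram)):
            bits[h & mask] = 1
    return bits

def may_contain(bits, patch_hash_list):
    """A patch is a possible match only if every one of its bits is set."""
    return all(bits[h] for h in patch_hash_list)

if __name__ == '__main__':
    lines = ['inta=1;', 'intb=2;', 'returna+b;', '}', '']
    bits = build_filter(lines)
    probe = ''.join(lines[0:NGRAM_SIZE])
    mask = BLOOMFILTER_SIZE - 1
    print(may_contain(bits, [fnv1a(probe) & mask,
                             djb2(probe) & mask,
                             sdbm(probe) & mask]))   # True: probe is the first n-gram

As with any Bloom filter, a positive answer from may_contain is only "possibly present" (false positives are expected), while a negative answer is definitive, which is why the loader reports "possible matches".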
traverse(source_path, patch)

Traverse source files

Source code in src/core/sourceLoader.py
def traverse(self, source_path, patch):
    '''
    Traverse source files
    '''
    common.verbose_print('[+] traversing source files')
    start_time = time.time()
    self._patch_list = patch.items()
    self._npatch = patch.length()

    if os.path.isfile(source_path):
        magic_type = common.file_type(source_path)
        common.verbose_print('  [-] %s: %s' % (source_path, magic_type))
        if magic_type.startswith('text'):
            main_type, sub_type = magic_type.split('/')
            magic_ext = self._get_file_type(sub_type)
            self._process(source_path, magic_ext)
    elif os.path.isdir(source_path):
        for root, dirs, files in os.walk(source_path):
            for file in files:
                file_path = os.path.join(root, file)
                magic_type = common.file_type(file_path)
                common.verbose_print('  [-] %s: %s' % (file_path, magic_type))
                if magic_type.startswith('text'):
                    main_type, sub_type = magic_type.split('/')
                    magic_ext = self._get_file_type(sub_type)
                    self._process(file_path, magic_ext)

    elapsed_time = time.time() - start_time
    common.verbose_print('[+] %d possible matches ... %.1fs\n' % (self._nmatch, elapsed_time))
    return self._nmatch
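
For reference, a hypothetical end-to-end invocation. The PatchLoader class, its module path, and its traverse() entry point are assumptions inferred from the interface this method expects (patch.items() returning objects with a hash_list attribute, and patch.length()); only SourceLoader itself is documented on this page.

from core.patchLoader import PatchLoader   # assumed companion class, not confirmed
from core.sourceLoader import SourceLoader

patches = PatchLoader()
patches.traverse('/path/to/fix.patch')     # assumed loading entry point

loader = SourceLoader()
n_possible = loader.traverse('/path/to/source/tree', patches)
print('%d possible matches' % n_possible)

# Per-patch, per-sequence hash hits collected during matching
for patch_id, per_seq in loader.match_items().items():
    print('patch #%d: %s' % (patch_id, per_seq))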