# Copyright 2013-2016 The Meson development team

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import sys
import os
import stat
import struct
import shutil
import subprocess
import typing as T

from ..mesonlib import OrderedSet, generate_list, Popen_safe

# ELF section header type (sh_type) that identifies a string table section.
SHT_STRTAB = 3
# Dynamic section tags; these are compared against DynamicEntry.d_tag.
# Name of a shared library the object depends on.
DT_NEEDED = 1
# Library search path (older tag).
DT_RPATH = 15
# Library search path (newer tag); handled the same way as DT_RPATH here.
DT_RUNPATH = 29
# Location of the dynamic string table.
DT_STRTAB = 5
# Shared object name of the library itself.
DT_SONAME = 14
# 0x70000035. Its value is relative to the position of the tag itself, so it
# has to be adjusted whenever entries in front of it are moved.
DT_MIPS_RLD_MAP_REL = 1879048245

# Global cache for tools.
# Starts as False and is set lazily by fix_rpath() once install_name_tool has
# been looked up on PATH.
INSTALL_NAME_TOOL = False

class DataSizes:
    """``struct`` format strings and byte sizes for the scalar ELF types.

    For every ELF scalar type ``X`` the instance carries ``self.X`` (a
    ``struct`` format string including the byte-order prefix) and
    ``self.XSize`` (the number of bytes that format consumes).

    ``XWord`` and ``Sxword`` only exist in the 64-bit ELF format, so those
    attributes are only set when *ptrsize* is 64.
    """

    def __init__(self, ptrsize: int, is_le: bool) -> None:
        """Select the formats for the given word size and byte order.

        :param ptrsize: 64 selects the 64-bit layouts; any other value
            selects the 32-bit layouts.
        :param is_le: True for little-endian files, False for big-endian.
        """
        p = '<' if is_le else '>'
        # Elf32_Half / Elf64_Half are *unsigned* 16-bit integers. Using the
        # signed 'h' code would turn header counts >= 0x8000 (for example
        # e_shnum) into negative numbers.
        self.Half = p + 'H'
        self.HalfSize = 2
        self.Word = p + 'I'
        self.WordSize = 4
        self.Sword = p + 'i'
        self.SwordSize = 4
        if ptrsize == 64:
            self.Addr = p + 'Q'
            self.AddrSize = 8
            self.Off = p + 'Q'
            self.OffSize = 8
            self.XWord = p + 'Q'
            self.XWordSize = 8
            self.Sxword = p + 'q'
            self.SxwordSize = 8
        else:
            self.Addr = p + 'I'
            self.AddrSize = 4
            self.Off = p + 'I'
            self.OffSize = 4

class DynamicEntry(DataSizes):
    """A single (tag, value) pair of the ELF dynamic section.

    The pair is decoded from the current position of the input stream on
    construction and can be encoded back to a stream with write().
    """

    def __init__(self, ifile: T.BinaryIO, ptrsize: int, is_le: bool) -> None:
        """Read ``d_tag`` and then ``val`` from *ifile*.

        64-bit files use a signed/unsigned 8-byte pair, everything else a
        signed/unsigned 4-byte pair.
        """
        super().__init__(ptrsize, is_le)
        self.ptrsize = ptrsize
        if ptrsize == 64:
            tag_fmt, tag_len = self.Sxword, self.SxwordSize
            val_fmt, val_len = self.XWord, self.XWordSize
        else:
            tag_fmt, tag_len = self.Sword, self.SwordSize
            val_fmt, val_len = self.Word, self.WordSize
        self.d_tag = struct.unpack(tag_fmt, ifile.read(tag_len))[0]
        self.val = struct.unpack(val_fmt, ifile.read(val_len))[0]

    def write(self, ofile: T.BinaryIO) -> None:
        """Write the entry to *ofile* in the same layout it was read with."""
        wide = self.ptrsize == 64
        tag_fmt = self.Sxword if wide else self.Sword
        val_fmt = self.XWord if wide else self.Word
        ofile.write(struct.pack(tag_fmt, self.d_tag))
        ofile.write(struct.pack(val_fmt, self.val))

class SectionHeader(DataSizes):
    """One entry of the ELF section header table.

    The fields are read from the input stream in file order: sh_name,
    sh_type, sh_flags, sh_addr, sh_offset, sh_size, sh_link, sh_info,
    sh_addralign and sh_entsize.
    """

    def __init__(self, ifile: T.BinaryIO, ptrsize: int, is_le: bool) -> None:
        """Decode one section header from the current position of *ifile*."""
        super().__init__(ptrsize, is_le)

        def take(fmt: str, size: int) -> int:
            # Consume `size` bytes and decode them as a single integer.
            return struct.unpack(fmt, ifile.read(size))[0]

        word = (self.Word, self.WordSize)
        # These fields are Elf64_Xword in 64-bit files and plain words otherwise.
        if ptrsize == 64:
            xword = (self.XWord, self.XWordSize)
        else:
            xword = word

        self.sh_name = take(*word)
        self.sh_type = take(*word)
        self.sh_flags = take(*xword)
        self.sh_addr = take(self.Addr, self.AddrSize)
        self.sh_offset = take(self.Off, self.OffSize)
        self.sh_size = take(*xword)
        self.sh_link = take(*word)
        self.sh_info = take(*word)
        self.sh_addralign = take(*xword)
        self.sh_entsize = take(*xword)

class Elf(DataSizes):
    """In-place reader and patcher for the dynamic section of an ELF file.

    Constructing an instance opens the file in ``r+b`` mode and parses the
    ELF header, every section header and the entries of the ``.dynamic``
    section. The file stays open until close_bf() runs, which happens when a
    ``with`` block is left, when the object is finalized, or when parsing
    fails inside the constructor.
    """

    def __init__(self, bfile: str, verbose: bool = True) -> None:
        """Open *bfile* and parse it.

        :param bfile: path of the file to inspect and modify.
        :param verbose: when True, informational messages are printed.
        :raises SystemExit: with code 0 when the file is not ELF, with a
            message when the ELF class or byte order is unknown.
        """
        self.bfile = bfile
        self.verbose = verbose
        self.sections = []  # type: T.List[SectionHeader]
        self.dynamic = []   # type: T.List[DynamicEntry]
        self.open_bf(bfile)
        try:
            (self.ptrsize, self.is_le) = self.detect_elf_type()
            super().__init__(self.ptrsize, self.is_le)
            self.parse_header()
            self.parse_sections()
            self.parse_dynamic()
        except BaseException:
            # Whatever went wrong (including the SystemExit raised for
            # non-ELF files), close the file and restore its permissions
            # right away instead of waiting for the finalizer.
            self.close_bf()
            raise

    def open_bf(self, bfile: str) -> None:
        """Open *bfile* for reading and writing.

        If that fails with PermissionError the file mode is temporarily
        changed to owner read/write/execute and the open is retried. The
        previous mode is remembered in ``self.bf_perms`` so that close_bf()
        can restore it. If the retry fails too, the mode is restored and the
        original PermissionError is raised.
        """
        self.bf = None
        self.bf_perms = None
        try:
            self.bf = open(bfile, 'r+b')
        except PermissionError as e:
            self.bf_perms = stat.S_IMODE(os.lstat(bfile).st_mode)
            os.chmod(bfile, stat.S_IREAD | stat.S_IWRITE | stat.S_IEXEC)
            try:
                self.bf = open(bfile, 'r+b')
            except Exception:
                os.chmod(bfile, self.bf_perms)
                self.bf_perms = None
                raise e

    def close_bf(self) -> None:
        """Close the file, restoring its original mode if open_bf() changed it.

        Safe to call more than once.
        """
        if self.bf is not None:
            if self.bf_perms is not None:
                os.fchmod(self.bf.fileno(), self.bf_perms)
                self.bf_perms = None
            self.bf.close()
            self.bf = None

    def __enter__(self) -> 'Elf':
        return self

    def __del__(self) -> None:
        self.close_bf()

    def __exit__(self, exc_type: T.Any, exc_value: T.Any, traceback: T.Any) -> None:
        self.close_bf()

    def detect_elf_type(self) -> T.Tuple[int, bool]:
        """Inspect the start of the file and return ``(ptrsize, is_le)``.

        :return: pointer size in bits (32 or 64) and whether the file is
            little-endian.
        :raises SystemExit: with code 0 if the file does not start with the
            ELF magic, or with an error message if the class or byte-order
            byte is missing or has an unknown value.
        """
        data = self.bf.read(6)
        if data[:4] != b'\x7fELF':
            # This script gets called to non-elf targets too
            # so just ignore them.
            if self.verbose:
                print(f'File {self.bfile!r} is not an ELF file.')
            sys.exit(0)
        # Slices (rather than indexing) keep a file that is truncated right
        # after the magic from raising IndexError; it is reported as unknown.
        ei_class = data[4:5]
        if ei_class == b'\x01':
            ptrsize = 32
        elif ei_class == b'\x02':
            ptrsize = 64
        else:
            sys.exit(f'File {self.bfile!r} has unknown ELF class.')
        ei_data = data[5:6]
        if ei_data == b'\x01':
            is_le = True
        elif ei_data == b'\x02':
            is_le = False
        else:
            sys.exit(f'File {self.bfile!r} has unknown ELF endianness.')
        return ptrsize, is_le

    def parse_header(self) -> None:
        """Read the ELF file header from offset 0 into ``self.e_*`` attributes."""
        self.bf.seek(0)
        self.e_ident = struct.unpack('16s', self.bf.read(16))[0]
        self.e_type = struct.unpack(self.Half, self.bf.read(self.HalfSize))[0]
        self.e_machine = struct.unpack(self.Half, self.bf.read(self.HalfSize))[0]
        self.e_version = struct.unpack(self.Word, self.bf.read(self.WordSize))[0]
        self.e_entry = struct.unpack(self.Addr, self.bf.read(self.AddrSize))[0]
        self.e_phoff = struct.unpack(self.Off, self.bf.read(self.OffSize))[0]
        self.e_shoff = struct.unpack(self.Off, self.bf.read(self.OffSize))[0]
        self.e_flags = struct.unpack(self.Word, self.bf.read(self.WordSize))[0]
        self.e_ehsize = struct.unpack(self.Half, self.bf.read(self.HalfSize))[0]
        self.e_phentsize = struct.unpack(self.Half, self.bf.read(self.HalfSize))[0]
        self.e_phnum = struct.unpack(self.Half, self.bf.read(self.HalfSize))[0]
        self.e_shentsize = struct.unpack(self.Half, self.bf.read(self.HalfSize))[0]
        self.e_shnum = struct.unpack(self.Half, self.bf.read(self.HalfSize))[0]
        self.e_shstrndx = struct.unpack(self.Half, self.bf.read(self.HalfSize))[0]

    def parse_sections(self) -> None:
        """Read ``e_shnum`` section headers starting at file offset ``e_shoff``."""
        self.bf.seek(self.e_shoff)
        for _ in range(self.e_shnum):
            self.sections.append(SectionHeader(self.bf, self.ptrsize, self.is_le))

    def read_str(self) -> bytes:
        """Read a NUL-terminated byte string from the current file position.

        :return: the bytes before the terminator (terminator not included).
        :raises RuntimeError: if the end of the file is reached first.
        """
        collected = bytearray()
        while True:
            byte = self.bf.read(1)
            if byte == b'\0':
                return bytes(collected)
            if byte == b'':
                raise RuntimeError('Tried to read past the end of the file')
            collected += byte

    def find_section(self, target_name: bytes) -> T.Optional[SectionHeader]:
        """Return the first section whose name equals *target_name*, or None.

        Names are looked up in the section referenced by ``e_shstrndx``.
        """
        section_names = self.sections[self.e_shstrndx]
        for i in self.sections:
            self.bf.seek(section_names.sh_offset + i.sh_name)
            name = self.read_str()
            if name == target_name:
                return i
        return None

    def parse_dynamic(self) -> None:
        """Load the ``.dynamic`` entries up to and including the first tag 0.

        Does nothing when the file has no ``.dynamic`` section.
        """
        sec = self.find_section(b'.dynamic')
        if sec is None:
            return
        self.bf.seek(sec.sh_offset)
        while True:
            e = DynamicEntry(self.bf, self.ptrsize, self.is_le)
            self.dynamic.append(e)
            if e.d_tag == 0:
                break

    @generate_list
    def get_section_names(self) -> T.Generator[str, None, None]:
        """Yield the decoded name of every section, in section table order."""
        section_names = self.sections[self.e_shstrndx]
        for i in self.sections:
            self.bf.seek(section_names.sh_offset + i.sh_name)
            yield self.read_str().decode()

    def get_soname(self) -> T.Optional[str]:
        """Return the DT_SONAME string, or None if the file has none.

        The string is located through the ``.dynstr`` section's file offset,
        exactly like get_rpath() and get_runpath(). The value stored in
        DT_STRTAB is a virtual address and cannot be used as a file position.
        """
        offset = self.get_entry_offset(DT_SONAME)
        if offset is None:
            return None
        self.bf.seek(offset)
        return self.read_str().decode()

    def get_entry_offset(self, entrynum: int) -> T.Optional[int]:
        """Return the file offset of the string a dynamic entry refers to.

        :param entrynum: the dynamic tag to look for (first match wins).
        :return: ``.dynstr`` file offset plus the entry's value, or None when
            there is no such entry or no ``.dynstr`` section.
        """
        sec = self.find_section(b'.dynstr')
        if sec is None:
            return None
        for i in self.dynamic:
            if i.d_tag == entrynum:
                res = sec.sh_offset + i.val
                assert isinstance(res, int)
                return res
        return None

    def get_rpath(self) -> T.Optional[str]:
        """Return the DT_RPATH string, or None if the entry is absent."""
        offset = self.get_entry_offset(DT_RPATH)
        if offset is None:
            return None
        self.bf.seek(offset)
        return self.read_str().decode()

    def get_runpath(self) -> T.Optional[str]:
        """Return the DT_RUNPATH string, or None if the entry is absent."""
        offset = self.get_entry_offset(DT_RUNPATH)
        if offset is None:
            return None
        self.bf.seek(offset)
        return self.read_str().decode()

    @generate_list
    def get_deps(self) -> T.Generator[str, None, None]:
        """Yield the name stored for every DT_NEEDED entry."""
        sec = self.find_section(b'.dynstr')
        if sec is None:
            return
        for i in self.dynamic:
            if i.d_tag == DT_NEEDED:
                offset = sec.sh_offset + i.val
                self.bf.seek(offset)
                yield self.read_str().decode()

    def fix_deps(self, prefix: bytes) -> None:
        """Reduce DT_NEEDED names that start with *prefix* to their basename.

        The basename is written over the old string and padded with NUL
        bytes, so the string table keeps its size.
        """
        sec = self.find_section(b'.dynstr')
        if sec is None:
            return
        deps = [i for i in self.dynamic if i.d_tag == DT_NEEDED]
        for i in deps:
            offset = sec.sh_offset + i.val
            self.bf.seek(offset)
            name = self.read_str()
            if name.startswith(prefix):
                basename = name.rsplit(b'/', maxsplit=1)[-1]
                padding = b'\0' * (len(name) - len(basename))
                newname = basename + padding
                assert len(newname) == len(name)
                self.bf.seek(offset)
                self.bf.write(newname)

    def fix_rpath(self, fname: str, rpath_dirs_to_remove: T.Set[bytes], new_rpath: bytes) -> None:
        """Rewrite both the DT_RPATH and the DT_RUNPATH entry of the file."""
        # The path to search for can be either rpath or runpath.
        # Fix both of them to be sure.
        self.fix_rpathtype_entry(fname, rpath_dirs_to_remove, new_rpath, DT_RPATH)
        self.fix_rpathtype_entry(fname, rpath_dirs_to_remove, new_rpath, DT_RUNPATH)

    def fix_rpathtype_entry(self, fname: str, rpath_dirs_to_remove: T.Set[bytes], new_rpath: bytes, entrynum: int) -> None:
        """Replace the search path stored for one dynamic tag.

        The resulting path consists of the entries of *new_rpath* followed by
        those entries of the existing path that are neither listed in
        *rpath_dirs_to_remove*, nor made up solely of ``X`` characters, nor
        empty. Duplicates are dropped. An empty result removes the dynamic
        entry altogether.

        :param fname: file name used in messages only.
        :param entrynum: DT_RPATH or DT_RUNPATH.
        :raises SystemExit: if the result is longer than the existing string.
        """
        rp_off = self.get_entry_offset(entrynum)
        if rp_off is None:
            if self.verbose:
                print(f'File {fname!r} does not have an rpath. It should be a fully static executable.')
            return
        self.bf.seek(rp_off)

        old_rpath = self.read_str()
        # Some rpath entries may come from multiple sources.
        # Only add each one once.
        new_rpaths = OrderedSet()  # type: OrderedSet[bytes]
        if new_rpath:
            new_rpaths.update(new_rpath.split(b':'))
        if old_rpath:
            # Filter out build-only rpath entries
            # added by get_link_dep_subdirs() or
            # specified by user with build_rpath.
            for rpath_dir in old_rpath.split(b':'):
                if not (rpath_dir in rpath_dirs_to_remove or
                        rpath_dir == (b'X' * len(rpath_dir))):
                    if rpath_dir:
                        new_rpaths.add(rpath_dir)

        # Prepend user-specified new entries while preserving the ones that came from pkgconfig etc.
        new_rpath = b':'.join(new_rpaths)

        if len(old_rpath) < len(new_rpath):
            msg = "New rpath must not be longer than the old one.\n Old: {}\n New: {}".format(old_rpath.decode('utf-8'), new_rpath.decode('utf-8'))
            sys.exit(msg)
        # The linker does read-only string deduplication. If there is a
        # string that shares a suffix with the rpath, they might get
        # dedupped. This means changing the rpath string might break something
        # completely unrelated. This has already happened once with X.org.
        # Thus we want to keep this change as small as possible to minimize
        # the chance of obliterating other strings. It might still happen
        # but our behavior is identical to what chrpath does and it has
        # been in use for ages so based on that this should be rare.
        if not new_rpath:
            self.remove_rpath_entry(entrynum)
        else:
            self.bf.seek(rp_off)
            self.bf.write(new_rpath)
            self.bf.write(b'\0')

    def remove_rpath_entry(self, entrynum: int) -> None:
        """Drop the first dynamic entry with tag *entrynum* and rewrite the table.

        The removed entry is turned into a tag 0 entry and moved to the end,
        so the table keeps its size. Nothing is written when the file has no
        ``.dynamic`` section or no entry with that tag.
        """
        sec = self.find_section(b'.dynamic')
        if sec is None:
            return None
        for (i, entry) in enumerate(self.dynamic):
            if entry.d_tag == entrynum:
                rpentry = self.dynamic[i]
                rpentry.d_tag = 0
                self.dynamic = self.dynamic[:i] + self.dynamic[i + 1:] + [rpentry]
                break
        else:
            # No such entry: leave the table alone instead of going on with
            # a stale (or, for an empty table, unbound) index.
            return None
        # DT_MIPS_RLD_MAP_REL is relative to the offset of the tag. Adjust it consequently.
        for entry in self.dynamic[i:]:
            if entry.d_tag == DT_MIPS_RLD_MAP_REL:
                # Every following entry moved up by one (tag, value) pair.
                entry.val += 2 * (self.ptrsize // 8)
                break
        self.bf.seek(sec.sh_offset)
        for entry in self.dynamic:
            entry.write(self.bf)
        return None

def fix_elf(fname: str, rpath_dirs_to_remove: T.Set[bytes], new_rpath: T.Optional[bytes], verbose: bool = True) -> None:
    """Rewrite the rpath/runpath of the ELF file *fname*.

    Nothing happens, and the file is not even opened, when *new_rpath* is
    None.
    """
    if new_rpath is None:
        return
    with Elf(fname, verbose) as elf:
        # note: elf.get_rpath() and elf.get_runpath() may be useful
        elf.fix_rpath(fname, rpath_dirs_to_remove, new_rpath)

def get_darwin_rpaths_to_remove(fname: str) -> T.List[str]:
    """Return every LC_RPATH path that ``otool -l`` reports for *fname*.

    :return: the paths in the order they are reported, duplicates included.
    :raises subprocess.CalledProcessError: if otool exits with a non-zero
        status.
    """
    p, out, _ = Popen_safe(['otool', '-l', fname], stderr=subprocess.DEVNULL)
    if p.returncode != 0:
        raise subprocess.CalledProcessError(p.returncode, p.args, out)
    result = []
    current_cmd = 'FOOBAR'
    for line in out.split('\n'):
        line = line.strip()
        if ' ' not in line:
            continue
        key, value = line.split(' ', 1)
        if key == 'cmd':
            current_cmd = value
        if key == 'path' and current_cmd == 'LC_RPATH':
            # The line looks like "path <rpath> (offset N)". Cut at the
            # trailing marker so that an rpath which itself contains '('
            # is kept whole.
            rp, marker, _ = value.rpartition(' (offset ')
            if not marker:
                # Unexpected layout: cut at the first parenthesis instead.
                rp = value.split('(', 1)[0]
            result.append(rp.strip())
    return result

def fix_darwin(fname: str, new_rpath: str, final_path: str, install_name_mappings: T.Dict[str, str]) -> None:
    """Fix rpaths and install names of a Mach-O file with install_name_tool.

    The rpaths currently present are deleted first. Then, in one further
    invocation, *new_rpath* is added (if non-empty), the id of a dylib is set
    to *final_path*, and every old -> new pair of *install_name_mappings* is
    applied. Returns silently when otool cannot read the file.

    :raises SystemExit: wrapping any exception raised while running
        install_name_tool.
    """
    try:
        existing_rpaths = get_darwin_rpaths_to_remove(fname)
    except subprocess.CalledProcessError:
        # Otool failed, which happens when invoked on a
        # non-executable target. Just return.
        return
    try:
        if existing_rpaths:
            # TODO: fix this properly, not totally clear how
            #
            # removing rpaths from binaries on macOS has tons of
            # weird edge cases. For instance, if the user provided
            # a '-Wl,-rpath' argument in LDFLAGS that happens to
            # coincide with an rpath generated from a dependency,
            # this would cause installation failures, as meson would
            # generate install_name_tool calls with two identical
            # '-delete_rpath' arguments, which install_name_tool
            # fails on. Because meson itself ensures that it never
            # adds duplicate rpaths, duplicate rpaths necessarily
            # come from user variables. The idea of using OrderedSet
            # is to remove *at most one* duplicate RPATH entry. This
            # is not optimal, as it only respects the user's choice
            # partially: if they provided a non-duplicate '-Wl,-rpath'
            # argument, it gets removed, if they provided a duplicate
            # one, it remains in the final binary. A potentially optimal
            # solution would split all user '-Wl,-rpath' arguments from
            # LDFLAGS, and later add them back with '-add_rpath'.
            delete_args = []
            for rp in OrderedSet(existing_rpaths):
                delete_args.extend(['-delete_rpath', rp])
            subprocess.check_call(['install_name_tool', fname] + delete_args,
                                  stdout=subprocess.DEVNULL,
                                  stderr=subprocess.DEVNULL)
        edit_args = []
        if new_rpath:
            edit_args.extend(['-add_rpath', new_rpath])
        # Rewrite -install_name @rpath/libfoo.dylib to /path/to/libfoo.dylib
        if fname.endswith('dylib'):
            edit_args.extend(['-id', final_path])
        if install_name_mappings:
            for old, new in install_name_mappings.items():
                edit_args.extend(['-change', old, new])
        if edit_args:
            subprocess.check_call(['install_name_tool', fname] + edit_args,
                                  stdout=subprocess.DEVNULL,
                                  stderr=subprocess.DEVNULL)
    except Exception as err:
        raise SystemExit(err)

def _without_class_path(lines: T.List[str]) -> T.List[str]:
    """Return manifest *lines* without the Class-Path header and its continuation lines."""
    kept = []
    in_class_path = False
    for line in lines:
        if line.startswith('Class-Path:'):
            in_class_path = True
            continue
        # A long header value is wrapped onto following lines that start
        # with a single space; those belong to the header being dropped.
        if in_class_path and line.startswith(' '):
            continue
        in_class_path = False
        kept.append(line)
    return kept

def fix_jar(fname: str) -> None:
    """Remove the ``Class-Path`` attribute from the manifest of the jar *fname*.

    The manifest is extracted to ``META-INF/MANIFEST.MF`` relative to the
    current working directory, rewritten there, and put back into the jar.

    :raises subprocess.CalledProcessError: if a ``jar`` invocation fails.
    """
    subprocess.check_call(['jar', 'xfv', fname, 'META-INF/MANIFEST.MF'])
    with open('META-INF/MANIFEST.MF', 'r+', encoding='utf-8') as f:
        lines = f.readlines()
        f.seek(0)
        f.writelines(_without_class_path(lines))
        f.truncate()
    # Updating with 'm' merges the given manifest into the existing one and
    # therefore cannot remove an attribute that is already present. With 'M'
    # no manifest handling takes place and the edited file replaces the old
    # manifest as an ordinary archive member.
    subprocess.check_call(['jar', 'ufM', fname, 'META-INF/MANIFEST.MF'])

def fix_rpath(fname: str, rpath_dirs_to_remove: T.Set[bytes], new_rpath: T.Union[str, bytes], final_path: str, install_name_mappings: T.Dict[str, str], verbose: bool = True) -> None:
    """Fix the runtime search path of *fname*, whatever kind of file it is.

    Files with a suffix that never needs fixing are skipped. Jars have their
    manifest cleaned. Everything else is treated as ELF first; if the ELF
    code exits with status 0 (not an ELF file) and install_name_tool is
    available, the file is handed to the macOS code path.

    :raises SystemExit: re-raised from the jar/ELF handling unless its code
        is the integer 0.
    """
    global INSTALL_NAME_TOOL  # pylint: disable=global-statement
    # Static libraries, import libraries, debug information, headers, etc
    # never have rpaths
    # DLLs and EXE currently do not need runtime path fixing
    skipped_suffixes = ('.a', '.lib', '.pdb', '.h', '.hpp', '.dll', '.exe')
    if fname.endswith(skipped_suffixes):
        return
    try:
        if fname.endswith('.jar'):
            fix_jar(fname)
        else:
            if isinstance(new_rpath, str):
                new_rpath = new_rpath.encode('utf8')
            fix_elf(fname, rpath_dirs_to_remove, new_rpath, verbose)
        return
    except SystemExit as e:
        # Exit status 0 means "nothing to do here"; anything else is an error.
        if not (isinstance(e.code, int) and e.code == 0):
            raise
    # We don't look for this on import because it will do a useless PATH lookup
    # on non-mac platforms. That can be expensive on some Windows machines
    # (up to 30ms), which is significant with --only-changed. For details, see:
    # https://github.com/mesonbuild/meson/pull/6612#discussion_r378581401
    # NOTE(review): False doubles as "not looked up yet", so a lookup that
    # finds nothing is repeated on the next call.
    if INSTALL_NAME_TOOL is False:
        INSTALL_NAME_TOOL = bool(shutil.which('install_name_tool'))
    if not INSTALL_NAME_TOOL:
        return
    if isinstance(new_rpath, bytes):
        new_rpath = new_rpath.decode('utf8')
    fix_darwin(fname, new_rpath, final_path, install_name_mappings)