# ============================================================================
# FILE: util.py
# AUTHOR: Shougo Matsushita <Shougo.Matsu at gmail.com>
# License: MIT license
# ============================================================================

from os.path import expandvars
from pathlib import Path
from pynvim import Nvim
from pynvim.api import Buffer
import glob
import importlib.util
import re
import sys
import traceback
import typing
import unicodedata

# Free-form dictionary describing the current completion context.
UserContext = typing.Dict[str, typing.Any]
# One completion item, e.g. {'word': ...}; additional keys are allowed.
Candidate = typing.Dict[str, typing.Any]
# An ordered list of completion items.
Candidates = typing.List[typing.Dict[str, typing.Any]]


def set_pattern(variable: typing.Dict[str, str],
                keys: str, pattern: typing.Any) -> None:
    """Assign ``pattern`` to every key listed in comma-separated ``keys``.

    ``variable`` is updated in place and existing entries are overwritten.
    Keys are used verbatim: surrounding whitespace is not stripped.
    """
    variable.update(dict.fromkeys(keys.split(','), pattern))


def convert2list(expr: typing.Any) -> typing.List[typing.Any]:
    """Return ``expr`` itself if it is a list, otherwise ``[expr]``."""
    if isinstance(expr, list):
        return expr
    return [expr]


def convert2candidates(li: typing.Any) -> Candidates:
    """Normalize a list of completion items into candidate dictionaries.

    For a non-empty list, every ``str`` element becomes ``{'word': elem}``
    and every other element is kept as is.  Any other input (an empty
    list, ``None``, a non-list value) is returned untouched.
    """
    if not li or not isinstance(li, list):
        return li
    return [{'word': item} if isinstance(item, str) else item
            for item in li]


def globruntime(runtimepath: str, path: str) -> typing.List[str]:
    """Glob ``path`` under each directory of a comma-separated runtimepath.

    Matches from each directory are concatenated in runtimepath order.
    """
    return [match
            for directory in runtimepath.split(',')
            for match in glob.glob(directory + '/' + path)]


def import_plugin(path: str, source: str,
                  classname: str) -> typing.Optional[typing.Any]:
    """Import a Deoplete plugin class from the Python file at ``path``.

    The file is executed as module ``deoplete.<source>.<name>``, where
    ``<name>`` is the file name with its last three characters (the ``.py``
    suffix) stripped.  Returns the module attribute ``classname``, or
    ``None`` when no import spec/loader can be built for the file or the
    attribute is missing (or falsy).  Exceptions raised while executing
    the module propagate to the caller.

    If the class exists, the file's directory is prepended to ``sys.path``
    (only once), presumably so the plugin can import modules beside it.
    """
    name = str(Path(path).name)[: -3]
    module_name = 'deoplete.%s.%s' % (source, name)

    spec = importlib.util.spec_from_file_location(module_name, path)
    # spec_from_file_location() returns None for unrecognised file types,
    # and ModuleSpec.loader is Optional; treat both like "not a plugin"
    # instead of failing with AttributeError.
    if spec is None or spec.loader is None:
        return None
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    cls = getattr(module, classname, None)
    if not cls:
        return None

    dirname = str(Path(path).parent)
    if dirname not in sys.path:
        sys.path.insert(0, dirname)
    return cls


def debug(vim: Nvim, expr: typing.Any) -> None:
    """Emit ``expr`` as a deoplete debug message.

    Writes ``[deoplete] <expr>`` through ``vim.out_write`` when the handle
    has that attribute; otherwise delegates to the Vim function
    ``deoplete#util#print_debug``.
    """
    if not hasattr(vim, 'out_write'):
        vim.call('deoplete#util#print_debug', expr)
        return
    vim.out_write(f'[deoplete] {expr}\n')


def error(vim: Nvim, expr: typing.Any) -> None:
    """Emit ``expr`` as a deoplete error message.

    Writes ``[deoplete] <expr>`` through ``vim.err_write`` when the handle
    has that attribute; otherwise delegates to the Vim function
    ``deoplete#util#print_error``.
    """
    if not hasattr(vim, 'err_write'):
        vim.call('deoplete#util#print_error', expr)
        return
    vim.err_write(f'[deoplete] {expr}\n')


def error_tb(vim: Nvim, msg: str) -> None:
    """Report ``msg`` as an error, preceded by the active traceback if any.

    When an exception is currently being handled, its formatted traceback
    lines come first.  With ``vim.err_write`` available everything is
    written as one ``[deoplete]``-prefixed block; otherwise each line is
    passed separately to ``deoplete#util#print_error``.
    """
    if all(sys.exc_info()):
        messages = traceback.format_exc().splitlines()
    else:
        messages = []
    messages.append('%s  Use :messages / see above for error details.' % msg)

    if not hasattr(vim, 'err_write'):
        for message in messages:
            vim.call('deoplete#util#print_error', message)
        return
    vim.err_write('[deoplete] %s\n' % '\n'.join(messages))


def error_vim(vim: Nvim, msg: str) -> None:
    """Report the pending Vim exception, then ``msg`` via ``error_tb``.

    ``v:throwpoint`` and then ``v:exception`` are read; each non-empty
    value is reported as ``<variable> = <value>`` before ``msg``.
    """
    for variable in ('v:throwpoint', 'v:exception'):
        value = vim.eval(variable)
        if value != '':
            error(vim, variable + ' = ' + value)
    error_tb(vim, msg)


def escape(expr: str) -> str:
    """Double every single quote in ``expr``.

    This is the escaping Vim requires inside single-quoted (literal)
    strings, where ``''`` stands for one ``'``.
    """
    return "''".join(expr.split("'"))


def charpos2bytepos(encoding: str, text: str, pos: int) -> int:
    """Return the byte length of the first ``pos`` characters of ``text``.

    Characters that ``encoding`` cannot represent are counted as the
    replacement byte(s) the codec emits instead of raising.
    """
    prefix = text[:pos]
    return len(prefix.encode(encoding, errors='replace'))


def bytepos2charpos(encoding: str, text: str, pos: int) -> int:
    """Return how many characters the first ``pos`` bytes of ``text`` form.

    ``text`` is encoded with ``encoding``, cut after ``pos`` bytes and
    decoded again.  Unencodable characters and a partial multi-byte
    sequence at the cut are replaced rather than raising.
    """
    encoded = text.encode(encoding, errors='replace')
    return len(encoded[:pos].decode(encoding, errors='replace'))


def get_custom(custom: typing.Dict[str, typing.Any],
               source_name: str, key: str,
               default: typing.Any) -> typing.Any:
    """Look up the custom option ``key`` for source ``source_name``.

    ``custom['source']`` maps source names to option dictionaries.  The
    entry for ``source_name`` is consulted first, then the ``'_'`` fallback
    entry; ``default`` is returned when neither defines ``key``.  Either
    entry may be absent.
    """
    custom_source = custom['source']
    # Iterate instead of recursing: the former recursive fallback looped
    # forever when neither ``source_name`` nor ``'_'`` was present, and
    # raised KeyError when only ``'_'`` was missing.
    for name in (source_name, '_'):
        options = custom_source.get(name, {})
        if key in options:
            return options[key]
    return default


def get_syn_names(vim: Nvim) -> typing.List[str]:
    """Return the value of ``deoplete#util#get_syn_names()`` as a list."""
    names = vim.call('deoplete#util#get_syn_names')
    return [*names]


def parse_file_pattern(f: typing.Iterable[str],
                       pattern: str) -> typing.List[str]:
    """Collect the distinct ``findall`` matches of ``pattern`` over ``f``.

    ``f`` is any iterable of strings (for example the lines of a file);
    each string is searched separately.  The result order is unspecified.
    """
    regex = re.compile(pattern)
    found: typing.Set[str] = set()
    for line in f:
        found.update(regex.findall(line))
    return list(found)


def parse_buffer_pattern(b: Buffer, pattern: str) -> typing.List[str]:
    """Return the distinct matches of ``pattern`` in buffer ``b``.

    The buffer lines are joined with ``'\\n'`` and searched as one string,
    so a pattern may match across line boundaries.  Order is unspecified.
    """
    text = '\n'.join(b)
    matches = re.findall(pattern, text)
    return list(set(matches))


def fuzzy_escape(string: str, camelcase: bool) -> str:
    """Build a fuzzy-matching regular expression from ``string``.

    ``string`` is regex-escaped first.  Each ASCII letter, digit or
    underscore ``c`` then becomes ``c[^c]*``: ``c`` followed by a run of
    other characters.  When ``camelcase`` is true and ``string`` contains
    an uppercase letter, each lowercase letter ``c`` instead becomes the
    class ``[cC]`` followed by ``.*``.
    """
    def either_case(match: typing.Any) -> str:
        letter = match.group(1)
        return f'[{letter}{letter.upper()}]'

    fuzzy = re.sub(r'([a-zA-Z0-9_])', r'\1.*', re.escape(string))
    if camelcase and re.search(r'[A-Z]', string):
        fuzzy = re.sub(r'([a-z])', either_case, fuzzy)
    # Tighten "c.*" into "c[^c]*" (the template fills \1 in both places).
    return re.sub(r'([a-zA-Z0-9_])\.\*', r'\1[^\1]*', fuzzy)


def load_external_module(base: str, module: str) -> None:
    """Put the directory ``module`` next to ``base``'s parent on sys.path.

    The parent directory of ``base`` is resolved, and the sibling directory
    named ``module`` is inserted at the front of ``sys.path`` unless it is
    already listed.
    """
    parent_dir = Path(base).parent.resolve()
    target = str(parent_dir.parent / module)
    if target in sys.path:
        return
    sys.path.insert(0, target)


def truncate_skipping(string: str, max_width: int,
                      footer: str, footer_len: int) -> str:
    """Shorten ``string`` to ``max_width`` display cells, keeping its end.

    ``string`` is returned unchanged when it is empty or already fits.
    Otherwise the result is a head of ``string`` followed by ``footer`` and
    then the trailing ``footer_len`` display cells of ``string``.  The head
    is truncated so the whole result fits in ``max_width`` cells whenever
    the footer part itself fits.
    """
    if not string:
        return ''
    if len(string) <= max_width / 2:
        return string
    if strwidth(string) <= max_width:
        return string

    # Take the last ``footer_len`` cells by truncating the reversed string
    # and reversing back.  The former ``string[-len(tail):]`` slice became
    # ``string[-0:]`` (the WHOLE string) when the tail was empty, e.g. for
    # ``footer_len <= 0``.
    tail = truncate(string[::-1], footer_len)[::-1]
    footer += tail
    return truncate(string, max_width - strwidth(footer)) + footer


def truncate(string: str, max_width: int) -> str:
    """Return the longest prefix of ``string`` fitting in ``max_width`` cells.

    ``string`` is returned unchanged when it already fits.  A string of at
    most ``max_width / 2`` characters is accepted without measuring it,
    because ``charwidth`` never reports more than two cells per character.
    """
    if len(string) <= max_width / 2 or strwidth(string) <= max_width:
        return string

    used = 0
    for index, char in enumerate(string):
        used += charwidth(char)
        if used > max_width:
            # ``char`` would overflow: keep everything before it.
            return string[:index]
    return string


def strwidth(string: str) -> int:
    """Return the display width of ``string``.

    The width is the sum of ``charwidth`` over its characters.
    """
    return sum(charwidth(char) for char in string)


def charwidth(c: str) -> int:
    """Return the display width (1 or 2 cells) of the single character ``c``.

    East Asian Fullwidth (``F``) and Wide (``W``) characters count as two
    cells; every other character counts as one.
    """
    return 2 if unicodedata.east_asian_width(c) in ('F', 'W') else 1


def expand(path: str) -> str:
    """Expand a leading ``~`` and then environment variables in ``path``.

    Home expansion goes through ``Path.expanduser`` and is best effort: if
    it fails, the ``~`` prefix is left as is.  ``$VAR`` references are
    always expanded with ``os.path.expandvars``.
    """
    if not path.startswith('~'):
        return expandvars(path)
    try:
        home_expanded = str(Path(path).expanduser())
    except Exception:
        home_expanded = path
    return expandvars(home_expanded)


def exists_path(path: str) -> bool:
    """Return whether ``path`` exists, treating any lookup error as "no"."""
    try:
        found = Path(path).exists()
    except Exception:
        found = False
    return found


def getlines(vim: Nvim, start: int = 1,
             end: typing.Union[int, str] = '$') -> typing.List[str]:
    """Return lines ``start``..``end`` (1-based, inclusive) of the buffer.

    ``end`` may be ``'$'`` for the last line of ``vim.current.buffer`` or a
    line number (an int, or a str convertible with ``int``).  Lines are
    fetched with Vim's ``getline()`` in chunks of at most 5001 lines, and
    lines of 300 or more characters are skipped.
    """
    last = len(vim.current.buffer) if end == '$' else int(end)
    # getline(a, b) is inclusive, so each request covers max_len + 1 lines.
    max_len = min(last - start, 5000)
    lines: typing.List[str] = []
    current = start
    while current <= last:
        # Clamp the final chunk so an explicit ``end`` is honoured instead of
        # reading up to ``max_len`` extra lines past it.
        chunk_end = min(current + max_len, last)
        # Skip very long lines
        lines += [x for x in vim.call('getline', current, chunk_end)
                  if len(x) < 300]
        current = chunk_end + 1
    return lines


def binary_search_begin(li: Candidates, prefix: str) -> int:
    """Return the index of the first candidate whose word starts with prefix.

    Matching is case-insensitive: both each ``'word'`` and ``prefix`` are
    lower-cased.  ``li`` must be sorted by lower-cased ``'word'`` for the
    search to be correct.  Returns -1 when ``li`` is empty or nothing
    matches.
    """
    # Lower-case up front so every element, including a lone one, is
    # compared the same way (the old single-element shortcut did not).
    prefix = prefix.lower()
    s = 0
    e = len(li)
    while s < e:
        index = (s + e) // 2
        word = li[index]['word'].lower()
        if word.startswith(prefix):
            # A match is the first one unless its predecessor matches too.
            if index == 0:
                return index
            prev_word = li[index - 1]['word'].lower()
            if not prev_word.startswith(prefix):
                return index
            e = index
        elif prefix < word:
            e = index
        else:
            s = index + 1
    return -1


def binary_search_end(li: Candidates, prefix: str) -> int:
    """Return the index of the last candidate whose word starts with prefix.

    Matching is case-insensitive: both each ``'word'`` and ``prefix`` are
    lower-cased.  ``li`` must be sorted by lower-cased ``'word'`` for the
    search to be correct.  Returns -1 when ``li`` is empty or nothing
    matches.
    """
    # Lower-case up front so every element, including a lone one, is
    # compared the same way (the old single-element shortcut did not).
    prefix = prefix.lower()
    s = 0
    e = len(li)
    while s < e:
        index = (s + e) // 2
        word = li[index]['word'].lower()
        if word.startswith(prefix):
            # A match is the last one unless its successor matches too.
            if index + 1 >= len(li):
                return index
            next_word = li[index + 1]['word'].lower()
            if not next_word.startswith(prefix):
                return index
            s = index + 1
        elif prefix < word:
            e = index
        else:
            s = index + 1
    return -1


def uniq_list_dict(li: typing.List[typing.Any]) -> typing.List[typing.Any]:
    """Return ``li`` without duplicates, keeping first occurrences in order.

    Items are compared with ``in`` (equality) against those already kept
    rather than hashed, so unhashable items such as dictionaries work.
    """
    unique: typing.List[typing.Any] = []
    for item in li:
        if item in unique:
            continue
        unique.append(item)
    return unique