import datetime
import uuid
import time
from typing import Optional, List

import requests
from urllib.parse import urlencode
import base64
import hashlib
from jose import jwt


class MSPAPIClient:
    """Client for the MSP REST API served at ``https://{host}/{api_version}/...``.

    Every request carries signed ``x-av-*`` headers (see :meth:`headers`). All
    requests except the auth request also carry ``x-av-token``, a token fetched
    from the ``auth`` endpoint and cached until shortly before its ``exp`` claim.
    """

    def __init__(self, client_id: Optional[str] = None, client_secret: Optional[str] = None,
                 host: Optional[str] = None, api_version: Optional[str] = None,
                 timeout: Optional[float] = None):
        """Store credentials and connection settings; no network I/O happens here.

        :param client_id: application id, sent as ``x-av-app-id`` and mixed into every signature.
        :param client_secret: secret mixed into every signature; it is not sent as a header.
        :param host: API host name, without scheme.
        :param api_version: path segment inserted after the host in every URL.
        :param timeout: optional ``requests`` timeout in seconds, applied to every HTTP call.
            ``None`` (the default) waits indefinitely, as the client always has.
        """
        self.client_id = client_id
        self.client_secret = client_secret
        # Cached auth token and its ``exp`` claim, compared against time.time().
        self.token: Optional[str] = None
        self.token_expiry = None
        # Treat the token as stale this many seconds before it actually expires.
        self.token_buffer = 60
        self.host = host
        self.api_version = api_version
        self.timeout = timeout

    def should_refresh_token(self) -> bool:
        """Return True when no cached token is usable.

        A token counts as stale ``token_buffer`` seconds before its expiry. A token
        without a known expiry also counts as stale, rather than failing the comparison.
        """
        if not self.token or self.token_expiry is None:
            return True
        return time.time() + self.token_buffer > self.token_expiry

    def generate_signature(self, request_id: str, timestamp: str, request_string: Optional[str] = None) -> str:
        """Return the hex SHA-256 of the base64 of the signing string.

        The signing string is ``request_id + client_id + timestamp + request_string
        + client_secret``. ``request_string`` is omitted when it is empty or None.
        """
        request_part = request_string or ''
        signature_string = f'{request_id}{self.client_id}{timestamp}{request_part}{self.client_secret}'
        signature_base64_bytes = base64.b64encode(signature_string.encode('utf-8'))
        return hashlib.sha256(signature_base64_bytes).hexdigest()

    def headers(self, request_string: Optional[str] = None, auth: bool = False) -> dict:
        """Build the signed headers for one request.

        :param request_string: path (plus query string) that is included in the signature.
        :param auth: True for the token request itself. It then omits ``x-av-token``,
            which avoids calling :meth:`get_token` recursively.
        """
        request_id = str(uuid.uuid4())
        # Naive UTC ISO-8601 timestamp (no "+00:00" suffix). This is the exact format
        # datetime.utcnow().isoformat() produced; utcnow() is deprecated since 3.12.
        timestamp = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat()
        headers = {
            'x-av-req-id': request_id,
            'x-av-app-id': self.client_id,
            'x-av-date': timestamp,
            'x-av-sig': self.generate_signature(request_id, timestamp, request_string)
        }
        if not auth:
            headers['x-av-token'] = self.get_token()
        return headers

    def get_token(self) -> str:
        """Return a valid auth token, fetching and verifying a new one if needed.

        :raises requests.exceptions.HTTPError: if the auth or public-key request fails.
        """
        if not self.should_refresh_token():
            return self.token

        res = requests.get(f'https://{self.host}/{self.api_version}/auth',
                           headers=self.headers(auth=True), timeout=self.timeout)
        res.raise_for_status()
        token = res.content.decode('utf-8')
        decoded_token = jwt.decode(token, self.public_key())
        # Commit token and expiry together, and only after a successful decode.
        # A failed decode or a missing 'exp' claim then leaves no half-populated cache.
        self.token = token
        self.token_expiry = decoded_token['exp']
        return self.token

    def public_key(self):
        """Fetch the JSON-decoded ``public_key`` response used to verify auth tokens."""
        res = requests.get(f'https://{self.host}/{self.api_version}/public_key', timeout=self.timeout)
        res.raise_for_status()
        return res.json()

    def call_api(self, method: str, endpoint: str, params: Optional[dict] = None,
                 body: Optional[dict] = None):
        """Send a signed request to ``/{api_version}/{endpoint}``.

        :param params: query parameters. They are signed and sent via ``requests``.
        :param body: JSON request body.
        :returns: the decoded JSON response, or None when the response body is empty.
        :raises requests.exceptions.HTTPError: on a non-2xx status. Status and body are
            printed first.
        """
        path = f'/{self.api_version}/{endpoint}'
        request_string = path
        if params:
            # The signed query string must match what requests sends. requests also
            # encodes params with list values as repeated keys (doseq semantics).
            request_string += f'?{urlencode(params, doseq=True)}'
        headers = self.headers(request_string)
        res = requests.request(method, f'https://{self.host}{path}', headers=headers,
                               params=params, json=body, timeout=self.timeout)
        try:
            res.raise_for_status()
        except requests.exceptions.HTTPError as e:
            print(f'request exception: status_code[{e.response.status_code}] response[{e.response.content}]')
            raise
        # Some endpoints (e.g. deletes) may reply with no body; res.json() would raise on that.
        if not res.content:
            return None
        return res.json()

    @staticmethod
    def strip_none(paylod):
        """Return a copy of the mapping without keys whose value is None.

        The parameter name is kept as-is so keyword callers keep working.
        """
        return {k: v for k, v in paylod.items() if v is not None}

    def list_tenants(self, msp_id: Optional[int] = None, scroll_id: Optional[str] = None):
        """GET ``msp/tenants``. Optional ``MSPId`` and ``scrollId`` go in ``requestData``."""
        body = {'scrollId': scroll_id, 'MSPId': msp_id}
        return self.call_api('GET', 'msp/tenants', body={'requestData': self.strip_none(body)})

    def create_tenant(self, admin_email: str, admin_name: str, tenant_name: str, admin_phone_number: str,
                      company_name: str, tenant_region: str, msp_id: int):
        """POST ``msp/tenants`` to create a tenant with the given admin and company details."""
        request_body = {
            'adminEmail': admin_email,
            'tenantName': tenant_name,
            'adminName': admin_name,
            'phone': admin_phone_number,
            'companyName': company_name,
            'tenantRegion': tenant_region,
            'MSPId': msp_id,
        }
        return self.call_api('POST', 'msp/tenants', body={'requestData': request_body})

    def tenant_details(self, tenant_id: int):
        """GET ``msp/tenants/{tenant_id}``."""
        return self.call_api('GET', f'msp/tenants/{tenant_id}')

    def delete_tenant(self, tenant_id: str):
        """DELETE ``msp/tenants/{tenant_id}``."""
        return self.call_api('DELETE', f'msp/tenants/{tenant_id}')

    def list_msp_packages(self):
        """GET ``msp/licenses``."""
        return self.call_api('GET', 'msp/licenses')

    def list_addons(self):
        """GET ``msp/addons``."""
        return self.call_api('GET', 'msp/addons')

    def license_tenant(self, tenant_id: int, license_code_name: str, addon_id_list: list, max_seats: int):
        """POST ``msp/tenants/{tenant_id}/license`` with the license code, add-ons and seat limit."""
        request_body = {
            'licenseCodeName': license_code_name,
            'addonIdList': addon_id_list,
            'maxLicensedUsers': max_seats
        }
        return self.call_api('POST', f'msp/tenants/{tenant_id}/license', body={'requestData': request_body})

    def list_monthly_usages(self, year: int, month: int, msp_ids: Optional[List[int]] = None,
                            scroll_id: Optional[str] = None):
        """GET ``msp/usage`` for a month.

        ``msp_ids`` and ``scroll_id`` are sent only when truthy.
        """
        params = {'year': year, 'month': month}
        body = {}
        if msp_ids:
            params['msp_ids'] = msp_ids
        if scroll_id:
            body['scrollId'] = scroll_id
        return self.call_api('GET', 'msp/usage', params=params, body={'requestData': body})

    def list_daily_usages(self, year: int, month: int, day: int, msp_ids: Optional[List[int]] = None,
                          scroll_id: Optional[str] = None):
        """GET ``msp/usage/day`` for one day. Parameters that are None are omitted."""
        params = {'year': year, 'month': month, 'day': day, 'msp_ids': msp_ids}
        body = {'scrollId': scroll_id}
        return self.call_api('GET', 'msp/usage/day', params=self.strip_none(params),
                             body={'requestData': self.strip_none(body)})

    def list_msps(self):
        """GET ``msp/msp-partners``."""
        return self.call_api('GET', 'msp/msp-partners')

    def create_msp(self, msp_name: str):
        """POST ``msp/msp-partners`` to create an MSP partner named ``msp_name``."""
        request_body = {'name': msp_name}
        return self.call_api('POST', 'msp/msp-partners', body={'requestData': request_body})

    def delete_msp(self, msp_id: int):
        """DELETE ``msp/msp-partners/{msp_id}``."""
        return self.call_api('DELETE', f'msp/msp-partners/{msp_id}')

    def list_msp_users(self, msp_id: Optional[int] = None, scroll_id: Optional[str] = None):
        """GET ``msp/users``. Optional ``MSPId`` and ``scrollId`` go in ``requestData``."""
        body = {'MSPId': msp_id, 'scrollId': scroll_id}
        return self.call_api('GET', 'msp/users', body={'requestData': self.strip_none(body)})

    def describe_user(self, user_id):
        """GET ``msp/users/{user_id}``."""
        return self.call_api('GET', f'msp/users/{user_id}')

    def create_user(self, msp_id: int, first_name: str, last_name: str, email: str, role: str,
                    direct_login: bool, saml_login: bool, view_private_data: bool,
                    receive_weekly_reports: bool, send_alerts: bool):
        """POST ``msp/users`` to create a user with the given role and login/notification flags."""
        request_body = {
            'MSPId': msp_id,
            'firstName': first_name,
            'lastName': last_name,
            'email': email,
            'role': role,
            'directLogin': direct_login,
            'samlLogin': saml_login,
            'viewPrivateData': view_private_data,
            'receiveWeeklyReports': receive_weekly_reports,
            'sendAlerts': send_alerts,
        }
        return self.call_api('POST', 'msp/users', body={'requestData': request_body})

    def edit_user(self, user_id: int, user_data: dict):
        """PUT ``msp/users/{user_id}`` with ``user_data`` as ``requestData``."""
        return self.call_api('PUT', f'msp/users/{user_id}', body={'requestData': user_data})

    def delete_user(self, user_id):
        """DELETE ``msp/users/{user_id}``."""
        return self.call_api('DELETE', f'msp/users/{user_id}')
