// Copyright 2016-2018 Yubico AB
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//go:build windows
// +build windows

package main

import (
	"fmt"
	"sync"
	"unsafe"

	log "github.com/sirupsen/logrus"
)

// #cgo CFLAGS: -DUNICODE -D_UNICODE
// #cgo LDFLAGS: -lwinusb -lsetupapi -luuid
// #include "usb_windows.h"
import "C"

// device is the single package-wide WinUSB state shared by the helpers in
// this file.
//
// ctx is nil while no device is open; usbopen treats a non-nil ctx as
// "already open", and usbclose passes its address to C.usbClose.
// mtx is locked by usbCheck and usbProxy for the whole operation; the
// lower-level helpers (usbopen, usbclose, usbreopen, usbwrite, usbread) do
// not lock it themselves. NOTE(review): they presumably must only be called
// with mtx held — confirm there are no other callers outside this view.
var device struct {
	ctx C.PDEVICE_CONTEXT
	mtx sync.Mutex
}

// C_DWORD wraps a raw Win32 DWORD status code so it can travel through Go
// code as an error value.
type C_DWORD C.DWORD

// Error renders the status code in hexadecimal, e.g. code 87 becomes
// "Windows Error: 0x57".
func (code C_DWORD) Error() string {
	value := uint(code)
	return fmt.Sprintf("Windows Error: 0x%x", value)
}

// Win32 status codes re-declared as C_DWORD so they can be returned and
// compared directly as Go error values. The names intentionally mirror the
// Windows SDK constants they are taken from (hence the non-Go ALL_CAPS style).
const (
	SUCCESS                 C_DWORD = C.ERROR_SUCCESS
	ERROR_INVALID_STATE     C_DWORD = C.ERROR_INVALID_STATE
	ERROR_INVALID_HANDLE    C_DWORD = C.ERROR_INVALID_HANDLE
	ERROR_INVALID_PARAMETER C_DWORD = C.ERROR_INVALID_PARAMETER
	ERROR_OUTOFMEMORY       C_DWORD = C.ERROR_OUTOFMEMORY
	ERROR_GEN_FAILURE       C_DWORD = C.ERROR_GEN_FAILURE
	ERROR_OBJECT_NOT_FOUND  C_DWORD = C.ERROR_OBJECT_NOT_FOUND
	ERROR_NOT_SUPPORTED     C_DWORD = C.ERROR_NOT_SUPPORTED
	ERROR_SHARING_VIOLATION C_DWORD = C.ERROR_SHARING_VIOLATION
	ERROR_BAD_COMMAND       C_DWORD = C.ERROR_BAD_COMMAND
)

// winusbError turns a Win32 status code into a Go error: ERROR_SUCCESS maps
// to a literal nil (avoiding a non-nil interface holding a zero code), and
// every other code is returned as a C_DWORD.
func winusbError(err C.DWORD) error {
	if err == C.ERROR_SUCCESS {
		return nil
	}
	return C_DWORD(err)
}

// usbopen opens the USB device with vendor 0x1050 / product 0x0030 into
// device.ctx, unless a context is already open, in which case it is a no-op.
//
// serial selects a specific device; an empty serial passes a NULL serial to
// usbOpen (presumably "any matching device" — confirm in usb_windows.c).
// The caller is expected to hold device.mtx.
//
// If no context was produced, the returned error starts with
// "device not found"; when usbOpen also reported a Windows error it is
// wrapped with %w so callers can still inspect it via errors.As.
func usbopen(cid string, serial string) (err error) {
	if device.ctx != nil {
		log.WithField("Correlation-ID", cid).Debug("usb context already open")
		return nil
	}

	var cSerial *C.char
	if serial != "" {
		cSerial = C.CString(serial)
		defer C.free(unsafe.Pointer(cSerial))
	}

	err = winusbError(C.usbOpen(0x1050, 0x0030, cSerial, &device.ctx))

	if device.ctx == nil {
		// Keep the underlying Windows error instead of discarding it.
		if err != nil {
			return fmt.Errorf("device not found: %w", err)
		}
		return fmt.Errorf("device not found")
	}

	// NOTE(review): if usbOpen reports an error but still sets ctx, the error
	// is returned while ctx stays set (unchanged from the original behavior).
	return err
}

// usbclose releases the current device context, if one is open.
//
// cid is unused; it is accepted for symmetry with the other helpers.
// usbClose receives the address of device.ctx — presumably so it can reset it
// to nil, which usbopen relies on to reopen after a close (TODO confirm in
// usb_windows.c).
func usbclose(cid string) {
	if device.ctx == nil {
		return
	}
	C.usbClose(&device.ctx)
}

// usbreopen closes the current device context and opens a fresh one.
//
// why is the error that prompted the reopen; it is only logged. The result
// of the new usbopen call is returned. The caller is expected to hold
// device.mtx.
func usbreopen(cid string, why error, serial string) (err error) {
	fields := log.Fields{
		"Correlation-ID": cid,
		"why":            why,
	}
	log.WithFields(fields).Debug("reopening usb context")

	usbclose(cid)
	err = usbopen(cid, serial)
	return err
}

// usbCheck verifies that the device selected by serial is open and responds
// to C.usbCheck, opening it first if needed. It holds device.mtx throughout.
//
// A failed check triggers a reopen and another check, up to maxAttempts
// checks in total (the same budget usbProxy uses). Previously this retried
// forever, so a device that opened fine but never passed the check would
// hang every caller while the mutex was held.
//
// Returns nil on success, the reopen error if reopening fails, or the last
// check error once the attempts are exhausted.
func usbCheck(cid string, serial string) (err error) {
	const maxAttempts = 2

	device.mtx.Lock()
	defer device.mtx.Unlock()

	if err = usbopen(cid, serial); err != nil {
		return err
	}

	for attempt := 1; ; attempt++ {
		err = winusbError(C.usbCheck(device.ctx, 0x1050, 0x0030))
		if err == nil {
			return nil
		}

		log.WithFields(log.Fields{
			"Correlation-ID": cid,
			"Error":          err,
		}).Debug("Couldn't check usb context")

		if attempt >= maxAttempts {
			return err
		}

		if err = usbreopen(cid, err, serial); err != nil {
			return err
		}
	}
}

// usbwrite sends buf to the device through C.usbWrite and logs the outcome
// (bytes written, error, length and payload) at debug level.
//
// An empty buf is rejected with ERROR_INVALID_PARAMETER: taking &buf[0] on an
// empty slice would panic. The caller is expected to hold device.mtx.
func usbwrite(buf []byte, cid string) (err error) {
	var n C.ULONG

	if len(buf) == 0 {
		err = ERROR_INVALID_PARAMETER
	} else {
		err = winusbError(C.usbWrite(
			device.ctx,
			(*C.UCHAR)(unsafe.Pointer(&buf[0])),
			C.ULONG(len(buf)),
			&n))
	}

	log.WithFields(log.Fields{
		"Correlation-ID": cid,
		"n":              uint(n),
		"err":            err,
		"len":            len(buf),
		"buf":            buf,
	}).Debug("usb endpoint write")

	return err
}

// usbread reads one response from the device into a fresh 8192-byte buffer
// via C.usbRead and logs the outcome at debug level.
//
// On success the returned slice is trimmed to the byte count usbRead
// reported; on failure it is empty (length 0) alongside the error. The
// caller is expected to hold device.mtx.
func usbread(cid string) (buf []byte, err error) {
	var got C.ULONG

	buf = make([]byte, 8192)
	rc := C.usbRead(
		device.ctx,
		(*C.UCHAR)(unsafe.Pointer(&buf[0])),
		C.ULONG(len(buf)),
		&got)

	if err = winusbError(rc); err == nil {
		buf = buf[:got]
	} else {
		buf = buf[:0]
	}

	log.WithFields(log.Fields{
		"Correlation-ID": cid,
		"n":              uint(got),
		"err":            err,
		"len":            len(buf),
		"buf":            buf,
	}).Debug("usb endpoint read")

	return buf, err
}

// usbProxy forwards req to the device selected by serial and returns the
// device's response. It holds device.mtx for the whole exchange.
//
// A failed write triggers a reopen and a retry, for at most two write
// attempts. A failed reopen aborts immediately with the reopen error. Once a
// write succeeds, the result of a single usbread is returned as-is (read
// errors are not retried).
func usbProxy(req []byte, cid string, serial string) (resp []byte, err error) {
	device.mtx.Lock()
	defer device.mtx.Unlock()

	err = usbopen(cid, serial)
	if err != nil {
		return nil, err
	}

	for attempt := 0; attempt < 2; attempt++ {
		err = usbwrite(req, cid)
		if err == nil {
			return usbread(cid)
		}
		if reopenErr := usbreopen(cid, err, serial); reopenErr != nil {
			return nil, reopenErr
		}
	}

	return resp, err
}
