#include "quantum.h"
#include "action_pseudo_lut.h"

static uint8_t send_key_shift_bit[SHIFT_BIT_SIZE];

/*
 * Pseudo layout action.
 * This action converts a keycode in order to output the character according to the keymap you specified
 * still your keyboard layout recognized wrongly on your OS.
 * Memo: Using other layer keymap to get keycode
 */
void action_pseudo_lut(keyrecord_t *record, uint8_t base_keymap_id, const uint16_t (*keymap)[2]) {
    uint8_t prev_shift;
    uint16_t keycode;
    uint16_t pseudo_keycode;

    /* get keycode from keymap you specified */
    keycode = keymap_key_to_keycode(base_keymap_id, record->event.key);

    prev_shift = keyboard_report->mods & (MOD_BIT(KC_LSFT) | MOD_BIT(KC_RSFT));

    if (record->event.pressed) {
        /* when magic commands entered, keycode does not converted */
        if (IS_COMMAND()) {
            if (prev_shift) {
                add_shift_bit(keycode);
            }
            register_code(keycode);
            return;
        }

        if (prev_shift) {
            pseudo_keycode = convert_keycode(keymap, keycode, true);
            dprintf("pressed: %02X, converted: %04X\n", keycode, pseudo_keycode);
            add_shift_bit(keycode);

            if (IS_LSFT(pseudo_keycode)) {
                register_code(QK_LSFT ^ pseudo_keycode);
            } else {
                /* delete shift mod temporarily */
                del_mods(prev_shift);
                send_keyboard_report();
                register_code(pseudo_keycode);
                add_mods(prev_shift);
                send_keyboard_report();
            }
        } else {
            pseudo_keycode = convert_keycode(keymap, keycode, false);
            dprintf("pressed: %02X, converted: %04X\n", keycode, pseudo_keycode);

            if (IS_LSFT(pseudo_keycode)) {
                add_weak_mods(MOD_BIT(KC_LSFT));
                send_keyboard_report();
                register_code(QK_LSFT ^ pseudo_keycode);
                /* on Windows, prevent key repeat to avoid unintended output */
                unregister_code(QK_LSFT ^ pseudo_keycode);
                del_weak_mods(MOD_BIT(KC_LSFT));
                send_keyboard_report();
            } else {
                register_code(pseudo_keycode);
            }
        }
    } else {
        if (get_shift_bit(keycode)) {
            del_shift_bit(keycode);
            pseudo_keycode = convert_keycode(keymap, keycode, true);
        } else {
            pseudo_keycode = convert_keycode(keymap, keycode, false);
        }
        dprintf("released: %02X, converted: %04X\n", keycode, pseudo_keycode);

        if (IS_LSFT(pseudo_keycode)) {
            unregister_code(QK_LSFT ^ pseudo_keycode);
        } else {
            unregister_code(pseudo_keycode);
        }
    }
}

uint16_t convert_keycode(const uint16_t (*keymap)[2], uint16_t keycode, bool shift_modded)
{
    uint16_t pseudo_keycode;

    switch (keycode) {
        case KC_A ... KC_CAPSLOCK:
#if defined(__AVR__)
            if (shift_modded) {
                pseudo_keycode = pgm_read_word(&keymap[keycode][1]);
            } else {
                pseudo_keycode = pgm_read_word(&keymap[keycode][0]);
            }
#else
            if (shift_modded) {
                pseudo_keycode = keymap[keycode][1];
            } else {
                pseudo_keycode = keymap[keycode][0];
            }
#endif
            /* if undefined, use got keycode as it is */
            if (pseudo_keycode == 0x00) {
                if (shift_modded) {
                    pseudo_keycode = S(keycode);
                } else {
                    pseudo_keycode = keycode;
                }
            }
            break;
        default:
            if (shift_modded) {
                pseudo_keycode = S(keycode);
            } else {
                pseudo_keycode = keycode;
            }
            break;
    }
    return pseudo_keycode;
}

uint8_t get_shift_bit(uint16_t keycode) {
    if ((keycode >> 3) < SHIFT_BIT_SIZE) {
        return send_key_shift_bit[keycode >> 3] & (1 << (keycode & 7));
    } else {
        dprintf("get_shift_bit: Can't get shift bit. keycode: %02X\n", keycode);
        return 0;
    }
}

void add_shift_bit(uint16_t keycode) {
    if ((keycode >> 3) < SHIFT_BIT_SIZE) {
        send_key_shift_bit[keycode >> 3] |= (1 << (keycode & 7));
    } else {
        dprintf("add_shift_bit: Can't add shift bit. keycode: %02X\n", keycode);
    }
}

void del_shift_bit(uint16_t keycode) {
    if ((keycode >> 3) < SHIFT_BIT_SIZE) {
        send_key_shift_bit[keycode >> 3] &= ~(1 << (keycode & 7));
    } else {
        dprintf("del_shift_bit: Can't delete shift bit. keycode: %02X\n", keycode);
    }
}