import { Utils } from '@sigmail/common';
import { HashFunction } from '@sigmail/crypto';
import { KeyDerivationFunction } from '.';
import { E_FAIL } from '../constants';
import * as Encoder from '../encoder';
import * as Hash from '../hash';
import { SigmailCryptoException } from '../SigmailCryptoException';

/**
 * Tunable parameters of the MGFV key derivation function.
 */
export interface Params {
  /**
   * Hash function applied to every block. Its `hashLength` property is read
   * as the digest size in bits.
   */
  hash: HashFunction;

  /**
   * Counter value used for the first hashed block; it is incremented by one
   * for each subsequent block.
   */
  initialCounter: number;

  /** Size, in bits, of the output. */
  outLength: number;
}

// Fallback values for every parameter the caller does not supply to the MGFV
// constructor: `Hash.SHA256`, a counter starting at 0 and a 256-bit output.
const DEFAULT_PARAMS: Params = {
  hash: Hash.SHA256,
  initialCounter: 0,
  outLength: 256
};

// Encodes a non-negative integer as 4 bytes in big-endian order (most
// significant byte first), keeping only the low 32 bits of the value.
//
// The most significant byte is computed with a division and Math.trunc rather
// than a shift, because shift operators first coerce their operand to a signed
// 32-bit integer. Per the original author this holds for any n < 2^54 at least.
//
// @author Kim Birchard <kbirchard@sigmahealthtech.com>
function i2osp(n: number): Uint8Array {
  const octets = new Uint8Array(4);

  octets[0] = Math.trunc(n / 0x1000000) & 0xff; // bits 24..31, via division
  octets[1] = (n >> 16) & 0xff; // bits 16..23
  octets[2] = (n >> 8) & 0xff; // bits 8..15
  octets[3] = n & 0xff; // bits 0..7

  return octets;
}

/**
 * Defines the primitive MGFV to generate pseudo-random bytes.
 *
 * This derives any length of bytes from a seed value and additional
 * parameters. It is an extended version of the standard MGF1 mask generation
 * function described in the PKCS RFCs.
 *
 * Differences from MGF1:
 * - the counter value is placed both at the start and at the end of the
 *   hashed input, so the entire hash must be performed for each block.
 * - the initial counter value is a parameter, to allow deriving independent
 *   values from the same seed.
 * - the seed consists of a binary array (or string) plus a version, which is
 *   either a number in the range 0..2147483647 (default 0) or a second,
 *   4-byte binary array.
 *
 * Layout of the buffer hashed for every block:
 *
 *   counter (4 bytes) | seed | version (4 bytes) | counter (4 bytes)
 *
 * @author Kim Birchard <kbirchard@sigmahealthtech.com>
 */
export class MGFV extends KeyDerivationFunction {
  private readonly params: Params;

  /**
   * @param params Optional overrides; every property left undefined falls
   * back to the corresponding default (see `DEFAULT_PARAMS`).
   */
  public constructor(params?: Partial<Params>) {
    super('MGFV');

    this.params = Utils.defaults({}, params, DEFAULT_PARAMS);
  }

  /**
   * @inheritdoc
   *
   * @param seed Seed material; a string is UTF-8 encoded before use.
   * @param version Optional second seed component: a finite number between 0
   * and 2147483647, or a `Uint8Array` of exactly 4 bytes. `null`/`undefined`
   * is treated as `0`.
   * @returns The derived bytes; the length is the configured `outLength`
   * converted from bits to whole bytes.
   * @throws SigmailCryptoException if `seed` or `version` has an unsupported
   * type or value, or if the configured hash does not report a positive,
   * finite `hashLength`.
   */
  public async derive(
    seed: string | Uint8Array,
    version?: number | Uint8Array | null | undefined
  ): Promise<Uint8Array> {
    let encodedSeed: Uint8Array;
    if (Utils.isString(seed)) {
      encodedSeed = Encoder.UTF8.encode(seed);
    } else if (seed instanceof Uint8Array) {
      encodedSeed = seed;
    } else {
      throw new SigmailCryptoException(
        E_FAIL,
        'Value was expected to be of type <string> or <Uint8Array>. (Parameter name: seed)'
      );
    }

    // a missing version is equivalent to version 0
    let uint8Version = i2osp(0);
    if (!Utils.isNil(version)) {
      if (Utils.isNumber(version)) {
        if (!Utils.isFinite(version) || version < 0 || version > 2147483647) {
          throw new SigmailCryptoException(E_FAIL, 'Value must be between 0 and 2147483647. (Parameter name: version)');
        }
        uint8Version = i2osp(version);
      } else if (!(version instanceof Uint8Array)) {
        throw new SigmailCryptoException(
          E_FAIL,
          'Value was expected to be of type <number> or <Uint8Array>. (Parameter name: version)'
        );
      } else if (version.length !== 4) {
        throw new SigmailCryptoException(E_FAIL, 'Length was expected to be 4. (Parameter name: version)');
      } else {
        uint8Version = version;
      }
    }

    const { initialCounter, outLength } = this.params;

    // Digest size in bytes. It is the loop step below, so a zero value would
    // never terminate and a NaN/infinite value would stop after one block,
    // leaving the remainder of the output zero-filled.
    const hashBytes = this.params.hash.hashLength / 8;
    if (!Number.isFinite(hashBytes) || hashBytes <= 0) {
      throw new SigmailCryptoException(
        E_FAIL,
        'Hash length must be a positive, finite number of bits. (Parameter name: hash)'
      );
    }

    // Output size in whole bytes. Truncating here matches the truncation the
    // Uint8Array constructor applies and avoids hashing one extra block whose
    // output would be discarded when outLength is not a multiple of 8.
    const outBytes = Math.trunc(outLength / 8);

    // length of the buffer to hash; the counter is inserted at start and end
    const blockLength = 4 + encodedSeed.length + uint8Version.length + 4;
    const counterOffsetEnd = blockLength - 4; // trailing counter, last 4 bytes

    // fixed parts of the buffer: seed followed by version
    const block = new Uint8Array(blockLength);
    block.set(encodedSeed, 4);
    block.set(uint8Version, 4 + encodedSeed.length);

    const result = new Uint8Array(outBytes);
    let counter = initialCounter;

    // hash one block per iteration until the result buffer is filled
    for (let offset = 0; offset < outBytes; offset += hashBytes, ++counter) {
      const uint8Counter = i2osp(counter);
      block.set(uint8Counter, 0);
      block.set(uint8Counter, counterOffsetEnd);

      // The buffer is reused between iterations, so blocks are hashed one
      // after another rather than concurrently.
      const hashed = await this.params.hash(block);

      // the final block may only be needed partially
      const remaining = outBytes - offset;
      result.set(remaining >= hashBytes ? hashed : hashed.subarray(0, remaining), offset);
    }

    return result;
  }
}
