From 2f49208dd18124e1126454a8fbcf6ea997cd3495 Mon Sep 17 00:00:00 2001 From: Fabian Fichter Date: Fri, 1 Oct 2021 12:01:02 +0200 Subject: [PATCH] Support NNUE with varying king dimensions This adds dedicated NNUE support for variants where kings only have access to a limited set of squares, like Xiangqi, or are missing entirely, like in antichess. Closes #346. --- src/nnue/features/half_ka_v2_variants.cpp | 11 ++++--- src/nnue/features/half_ka_v2_variants.h | 2 +- src/nnue/nnue_feature_transformer.h | 2 +- src/position.h | 5 +++ src/variant.h | 41 ++++++++++++++++++++++++---- 5 files changed, 48 insertions(+), 13 deletions(-) diff --git a/src/nnue/features/half_ka_v2_variants.cpp b/src/nnue/features/half_ka_v2_variants.cpp index 692ad52..b06c7a7 100644 --- a/src/nnue/features/half_ka_v2_variants.cpp +++ b/src/nnue/features/half_ka_v2_variants.cpp @@ -30,19 +30,20 @@ namespace Stockfish::Eval::NNUE::Features { } // Orient a square according to perspective (rotates by 180 for black) + // Missing kings map to index 0 (SQ_A1) inline Square HalfKAv2Variants::orient(Color perspective, Square s, const Position& pos) { - return to_variant_square( perspective == WHITE || (pos.capture_the_flag(BLACK) & Rank8BB) ? s - : flip_rank(s, pos.max_rank()), pos); + return s != SQ_NONE ? to_variant_square( perspective == WHITE || (pos.capture_the_flag(BLACK) & Rank8BB) ? s + : flip_rank(s, pos.max_rank()), pos) : SQ_A1; } // Index of a feature for a given king position and another piece on some square inline IndexType HalfKAv2Variants::make_index(Color perspective, Square s, Piece pc, Square ksq, const Position& pos) { - return IndexType(orient(perspective, s, pos) + pos.variant()->pieceSquareIndex[perspective][pc] + ksq * pos.variant()->nnuePieceIndices); + return IndexType(orient(perspective, s, pos) + pos.variant()->pieceSquareIndex[perspective][pc] + pos.variant()->kingSquareIndex[ksq]); } // Index of a feature for a given king position and another piece on some square inline IndexType HalfKAv2Variants::make_index(Color perspective, int handCount, Piece pc, Square ksq, const Position& pos) { - return IndexType(handCount + pos.variant()->pieceHandIndex[perspective][pc] + ksq * pos.variant()->nnuePieceIndices); + return IndexType(handCount + pos.variant()->pieceHandIndex[perspective][pc] + pos.variant()->kingSquareIndex[ksq]); } // Get a list of indices for active features @@ -51,7 +52,7 @@ namespace Stockfish::Eval::NNUE::Features { Color perspective, ValueListInserter active ) { - Square oriented_ksq = orient(perspective, pos.square(perspective, pos.nnue_king()), pos); + Square oriented_ksq = orient(perspective, pos.nnue_king_square(perspective), pos); Bitboard bb = pos.pieces(); while (bb) { diff --git a/src/nnue/features/half_ka_v2_variants.h b/src/nnue/features/half_ka_v2_variants.h index e55ebcb..d8815f1 100644 --- a/src/nnue/features/half_ka_v2_variants.h +++ b/src/nnue/features/half_ka_v2_variants.h @@ -58,7 +58,7 @@ namespace Stockfish::Eval::NNUE::Features { static constexpr IndexType Dimensions = static_cast(SQUARE_NB) * static_cast(SQUARE_NB) * 19; static IndexType get_dimensions() { - return currentNnueVariant->nnueSquares * currentNnueVariant->nnuePieceIndices; + return currentNnueVariant->nnueDimensions; } // Maximum number of simultaneously active features. diff --git a/src/nnue/nnue_feature_transformer.h b/src/nnue/nnue_feature_transformer.h index b4eaab3..18cb582 100644 --- a/src/nnue/nnue_feature_transformer.h +++ b/src/nnue/nnue_feature_transformer.h @@ -403,7 +403,7 @@ namespace Stockfish::Eval::NNUE { // accumulator. Then, we update the current accumulator (pos.state()). // Gather all features to be updated. - const Square ksq = pos.square(perspective, pos.nnue_king()); + const Square ksq = pos.nnue_king_square(perspective); IndexList removed[2], added[2]; FeatureSet::append_changed_indices( ksq, next, perspective, removed[0], added[0], pos); diff --git a/src/position.h b/src/position.h index 6e7c1dc..b930fa5 100644 --- a/src/position.h +++ b/src/position.h @@ -147,6 +147,7 @@ public: PieceType castling_rook_piece() const; PieceType king_type() const; PieceType nnue_king() const; + Square nnue_king_square(Color c) const; bool nnue_use_pockets() const; bool nnue_applicable() const; bool checking_permitted() const; @@ -528,6 +529,10 @@ inline PieceType Position::nnue_king() const { return var->nnueKing; } +inline Square Position::nnue_king_square(Color c) const { + return nnue_king() ? square(c, nnue_king()) : SQ_NONE; +} + inline bool Position::nnue_use_pockets() const { assert(var != nullptr); return var->nnueUsePockets; diff --git a/src/variant.h b/src/variant.h index db55426..852eb4a 100644 --- a/src/variant.h +++ b/src/variant.h @@ -25,6 +25,7 @@ #include #include #include +#include #include "types.h" #include "bitboard.h" @@ -138,11 +139,11 @@ struct Variant { bool fastAttacks2 = true; std::string nnueAlias = ""; PieceType nnueKing = KING; - int nnueSquares; + int nnueDimensions; bool nnueUsePockets; - int nnuePieceIndices; int pieceSquareIndex[COLOR_NB][PIECE_NB]; int pieceHandIndex[COLOR_NB][PIECE_NB]; + int kingSquareIndex[SQUARE_NB]; int nnueMaxPieces; bool endgameEval = false; @@ -205,13 +206,21 @@ struct Variant { // Initialize calculated NNUE properties nnueKing = pieceTypes.find(KING) != pieceTypes.end() ? KING - : extinctionPieceTypes.find(COMMONER) != extinctionPieceTypes.end() ? COMMONER + : extinctionPieceCount == 0 && extinctionPieceTypes.find(COMMONER) != extinctionPieceTypes.end() ? COMMONER : NO_PIECE_TYPE; - nnueSquares = (maxRank + 1) * (maxFile + 1); + if (nnueKing != NO_PIECE_TYPE) + { + std::string fenBoard = startFen.substr(0, startFen.find(' ')); + // Switch NNUE from KA to A if there is no unique piece + if ( std::count(fenBoard.begin(), fenBoard.end(), pieceToChar[make_piece(WHITE, nnueKing)]) != 1 + || std::count(fenBoard.begin(), fenBoard.end(), pieceToChar[make_piece(BLACK, nnueKing)]) != 1) + nnueKing = NO_PIECE_TYPE; + } + int nnueSquares = (maxRank + 1) * (maxFile + 1); nnueUsePockets = (pieceDrops && (!mustDrop || capturesToHand)) || seirawanGating; int nnuePockets = nnueUsePockets ? 2 * int(maxFile + 1) : 0; - int nnueNonDropPieceIndices = (2 * pieceTypes.size() - 1) * nnueSquares; - nnuePieceIndices = nnueNonDropPieceIndices + 2 * (pieceTypes.size() - 1) * nnuePockets; + int nnueNonDropPieceIndices = (2 * pieceTypes.size() - (nnueKing != NO_PIECE_TYPE)) * nnueSquares; + int nnuePieceIndices = nnueNonDropPieceIndices + 2 * (pieceTypes.size() - (nnueKing != NO_PIECE_TYPE)) * nnuePockets; int i = 0; for (PieceType pt : pieceTypes) { @@ -224,6 +233,26 @@ struct Variant { } i++; } + + // Map king squares to enumeration of actually available squares. + // E.g., for xiangqi map from 0-89 to 0-8. + // Variants might be initialized before bitboards, so do not rely on precomputed bitboards (like SquareBB). + int nnueKingSquare = 0; + if (nnueKing) + for (Square s = SQ_A1; s < nnueSquares; ++s) + { + Square bitboardSquare = Square(s + s / (maxFile + 1) * (FILE_MAX - maxFile)); + if ( !mobilityRegion[WHITE][nnueKing] || !mobilityRegion[BLACK][nnueKing] + || (mobilityRegion[WHITE][nnueKing] & make_bitboard(bitboardSquare)) + || (mobilityRegion[BLACK][nnueKing] & make_bitboard(relative_square(BLACK, bitboardSquare, maxRank)))) + { + kingSquareIndex[s] = nnueKingSquare++ * nnuePieceIndices; + } + } + else + kingSquareIndex[SQ_A1] = nnueKingSquare++ * nnuePieceIndices; + nnueDimensions = nnueKingSquare * nnuePieceIndices; + // Determine maximum piece count std::istringstream ss(startFen); ss >> std::noskipws; -- 1.7.0.4