Support NNUE with varying king dimensions
authorFabian Fichter <ianfab@users.noreply.github.com>
Fri, 1 Oct 2021 10:01:02 +0000 (12:01 +0200)
committerFabian Fichter <ianfab@users.noreply.github.com>
Fri, 1 Oct 2021 11:26:09 +0000 (13:26 +0200)
This adds dedicated NNUE support for variants where kings
only have access to a limited set of squares, like Xiangqi,
or are missing entirely, like in antichess.

Closes #346.

src/nnue/features/half_ka_v2_variants.cpp
src/nnue/features/half_ka_v2_variants.h
src/nnue/nnue_feature_transformer.h
src/position.h
src/variant.h

index 692ad52..b06c7a7 100644 (file)
@@ -30,19 +30,20 @@ namespace Stockfish::Eval::NNUE::Features {
   }
 
   // Orient a square according to perspective (rotates by 180 for black)
+  // Missing kings map to index 0 (SQ_A1)
   inline Square HalfKAv2Variants::orient(Color perspective, Square s, const Position& pos) {
-    return to_variant_square(  perspective == WHITE || (pos.capture_the_flag(BLACK) & Rank8BB) ? s
-                             : flip_rank(s, pos.max_rank()), pos);
+    return s != SQ_NONE ? to_variant_square(  perspective == WHITE || (pos.capture_the_flag(BLACK) & Rank8BB) ? s
+                                            : flip_rank(s, pos.max_rank()), pos) : SQ_A1;
   }
 
   // Index of a feature for a given king position and another piece on some square
   inline IndexType HalfKAv2Variants::make_index(Color perspective, Square s, Piece pc, Square ksq, const Position& pos) {
-    return IndexType(orient(perspective, s, pos) + pos.variant()->pieceSquareIndex[perspective][pc] + ksq * pos.variant()->nnuePieceIndices);
+    return IndexType(orient(perspective, s, pos) + pos.variant()->pieceSquareIndex[perspective][pc] + pos.variant()->kingSquareIndex[ksq]);
   }
 
   // Index of a feature for a given king position and another piece on some square
   inline IndexType HalfKAv2Variants::make_index(Color perspective, int handCount, Piece pc, Square ksq, const Position& pos) {
-    return IndexType(handCount + pos.variant()->pieceHandIndex[perspective][pc] + ksq * pos.variant()->nnuePieceIndices);
+    return IndexType(handCount + pos.variant()->pieceHandIndex[perspective][pc] + pos.variant()->kingSquareIndex[ksq]);
   }
 
   // Get a list of indices for active features
@@ -51,7 +52,7 @@ namespace Stockfish::Eval::NNUE::Features {
     Color perspective,
     ValueListInserter<IndexType> active
   ) {
-    Square oriented_ksq = orient(perspective, pos.square(perspective, pos.nnue_king()), pos);
+    Square oriented_ksq = orient(perspective, pos.nnue_king_square(perspective), pos);
     Bitboard bb = pos.pieces();
     while (bb)
     {
index e55ebcb..d8815f1 100644 (file)
@@ -58,7 +58,7 @@ namespace Stockfish::Eval::NNUE::Features {
     static constexpr IndexType Dimensions = static_cast<IndexType>(SQUARE_NB) * static_cast<IndexType>(SQUARE_NB) * 19;
 
     static IndexType get_dimensions() {
-      return currentNnueVariant->nnueSquares * currentNnueVariant->nnuePieceIndices;
+      return currentNnueVariant->nnueDimensions;
     }
 
     // Maximum number of simultaneously active features.
index b4eaab3..18cb582 100644 (file)
@@ -403,7 +403,7 @@ namespace Stockfish::Eval::NNUE {
         // accumulator. Then, we update the current accumulator (pos.state()).
 
         // Gather all features to be updated.
-        const Square ksq = pos.square(perspective, pos.nnue_king());
+        const Square ksq = pos.nnue_king_square(perspective);
         IndexList removed[2], added[2];
         FeatureSet::append_changed_indices(
           ksq, next, perspective, removed[0], added[0], pos);
index 6e7c1dc..b930fa5 100644 (file)
@@ -147,6 +147,7 @@ public:
   PieceType castling_rook_piece() const;
   PieceType king_type() const;
   PieceType nnue_king() const;
+  Square nnue_king_square(Color c) const;
   bool nnue_use_pockets() const;
   bool nnue_applicable() const;
   bool checking_permitted() const;
@@ -528,6 +529,10 @@ inline PieceType Position::nnue_king() const {
   return var->nnueKing;
 }
 
+inline Square Position::nnue_king_square(Color c) const {
+  return nnue_king() ? square(c, nnue_king()) : SQ_NONE;
+}
+
 inline bool Position::nnue_use_pockets() const {
   assert(var != nullptr);
   return var->nnueUsePockets;
index db55426..852eb4a 100644 (file)
@@ -25,6 +25,7 @@
 #include <string>
 #include <functional>
 #include <sstream>
+#include <iostream>
 
 #include "types.h"
 #include "bitboard.h"
@@ -138,11 +139,11 @@ struct Variant {
   bool fastAttacks2 = true;
   std::string nnueAlias = "";
   PieceType nnueKing = KING;
-  int nnueSquares;
+  int nnueDimensions;
   bool nnueUsePockets;
-  int nnuePieceIndices;
   int pieceSquareIndex[COLOR_NB][PIECE_NB];
   int pieceHandIndex[COLOR_NB][PIECE_NB];
+  int kingSquareIndex[SQUARE_NB];
   int nnueMaxPieces;
   bool endgameEval = false;
 
@@ -205,13 +206,21 @@ struct Variant {
 
       // Initialize calculated NNUE properties
       nnueKing =  pieceTypes.find(KING) != pieceTypes.end() ? KING
-                : extinctionPieceTypes.find(COMMONER) != extinctionPieceTypes.end() ? COMMONER
+                : extinctionPieceCount == 0 && extinctionPieceTypes.find(COMMONER) != extinctionPieceTypes.end() ? COMMONER
                 : NO_PIECE_TYPE;
-      nnueSquares = (maxRank + 1) * (maxFile + 1);
+      if (nnueKing != NO_PIECE_TYPE)
+      {
+          std::string fenBoard = startFen.substr(0, startFen.find(' '));
+          // Switch NNUE from KA to A if there is no unique piece
+          if (   std::count(fenBoard.begin(), fenBoard.end(), pieceToChar[make_piece(WHITE, nnueKing)]) != 1
+              || std::count(fenBoard.begin(), fenBoard.end(), pieceToChar[make_piece(BLACK, nnueKing)]) != 1)
+              nnueKing = NO_PIECE_TYPE;
+      }
+      int nnueSquares = (maxRank + 1) * (maxFile + 1);
       nnueUsePockets = (pieceDrops && (!mustDrop || capturesToHand)) || seirawanGating;
       int nnuePockets = nnueUsePockets ? 2 * int(maxFile + 1) : 0;
-      int nnueNonDropPieceIndices = (2 * pieceTypes.size() - 1) * nnueSquares;
-      nnuePieceIndices = nnueNonDropPieceIndices + 2 * (pieceTypes.size() - 1) * nnuePockets;
+      int nnueNonDropPieceIndices = (2 * pieceTypes.size() - (nnueKing != NO_PIECE_TYPE)) * nnueSquares;
+      int nnuePieceIndices = nnueNonDropPieceIndices + 2 * (pieceTypes.size() - (nnueKing != NO_PIECE_TYPE)) * nnuePockets;
       int i = 0;
       for (PieceType pt : pieceTypes)
       {
@@ -224,6 +233,26 @@ struct Variant {
           }
           i++;
       }
+
+      // Map king squares to enumeration of actually available squares.
+      // E.g., for xiangqi map from 0-89 to 0-8.
+      // Variants might be initialized before bitboards, so do not rely on precomputed bitboards (like SquareBB).
+      int nnueKingSquare = 0;
+      if (nnueKing)
+          for (Square s = SQ_A1; s < nnueSquares; ++s)
+          {
+              Square bitboardSquare = Square(s + s / (maxFile + 1) * (FILE_MAX - maxFile));
+              if (   !mobilityRegion[WHITE][nnueKing] || !mobilityRegion[BLACK][nnueKing]
+                  || (mobilityRegion[WHITE][nnueKing] & make_bitboard(bitboardSquare))
+                  || (mobilityRegion[BLACK][nnueKing] & make_bitboard(relative_square(BLACK, bitboardSquare, maxRank))))
+              {
+                  kingSquareIndex[s] = nnueKingSquare++ * nnuePieceIndices;
+              }
+          }
+      else
+          kingSquareIndex[SQ_A1] = nnueKingSquare++ * nnuePieceIndices;
+      nnueDimensions = nnueKingSquare * nnuePieceIndices;
+
       // Determine maximum piece count
       std::istringstream ss(startFen);
       ss >> std::noskipws;