Support atomic NNUE
authorFabian Fichter <ianfab@users.noreply.github.com>
Fri, 26 Feb 2021 17:33:40 +0000 (18:33 +0100)
committerFabian Fichter <ianfab@users.noreply.github.com>
Fri, 26 Feb 2021 17:33:40 +0000 (18:33 +0100)
src/nnue/features/half_kp.cpp
src/nnue/nnue_feature_transformer.h
src/position.cpp
src/position.h
src/types.h
src/variant.h

index 4b90f47..508b024 100644 (file)
@@ -44,8 +44,8 @@ namespace Eval::NNUE::Features {
   void HalfKP<AssociatedKing>::AppendActiveIndices(
       const Position& pos, Color perspective, IndexList* active) {
 
-    Square ksq = orient(pos, perspective, pos.square<KING>(perspective));
-    Bitboard bb = pos.pieces() & ~pos.pieces(KING);
+    Square ksq = orient(pos, perspective, pos.square(perspective, pos.nnue_king()));
+    Bitboard bb = pos.pieces() & ~pos.pieces(pos.nnue_king());
     while (bb) {
       Square s = pop_lsb(&bb);
       active->push_back(make_index(pos, perspective, s, pos.piece_on(s), ksq));
@@ -58,10 +58,10 @@ namespace Eval::NNUE::Features {
       const Position& pos, const DirtyPiece& dp, Color perspective,
       IndexList* removed, IndexList* added) {
 
-    Square ksq = orient(pos, perspective, pos.square<KING>(perspective));
+    Square ksq = orient(pos, perspective, pos.square(perspective, pos.nnue_king()));
     for (int i = 0; i < dp.dirty_num; ++i) {
       Piece pc = dp.piece[i];
-      if (type_of(pc) == KING) continue;
+      if (type_of(pc) == pos.nnue_king()) continue;
       if (dp.from[i] != SQ_NONE)
         removed->push_back(make_index(pos, perspective, dp.from[i], pc, ksq));
       if (dp.to[i] != SQ_NONE)
index 2641321..38395c0 100644 (file)
@@ -256,7 +256,7 @@ namespace Eval::NNUE {
         static_assert(std::is_same_v<RawFeatures::SortedTriggerSet,
               Features::CompileTimeList<Features::TriggerEvent, Features::TriggerEvent::kFriendKingMoved>>,
               "Current code assumes that only kFriendlyKingMoved refresh trigger is being used.");
-        if (   dp.piece[0] == make_piece(c, KING)
+        if (   dp.piece[0] == make_piece(c, pos.nnue_king())
             || (gain -= dp.dirty_num + 1) < 0)
           break;
         next = st;
index f01aec4..065f34e 100644 (file)
@@ -1624,6 +1624,14 @@ void Position::do_move(Move m, StateInfo& newSt, bool givesCheck) {
           if (type_of(bpc) != PAWN)
               st->nonPawnMaterial[bc] -= PieceValue[MG][bpc];
 
+          if (Eval::useNNUE)
+          {
+              dp.piece[dp.dirty_num] = bpc;
+              dp.from[dp.dirty_num] = bsq;
+              dp.to[dp.dirty_num] = SQ_NONE;
+              dp.dirty_num++;
+          }
+
           // Update board and piece lists
           // In order to not have to store the values of both board and unpromotedBoard,
           // demote promoted pieces, but keep promoted pawns as promoted,
index 5ac26ef..e439ad7 100644 (file)
@@ -141,6 +141,7 @@ public:
   PieceType castling_king_piece() const;
   PieceType castling_rook_piece() const;
   PieceType king_type() const;
+  PieceType nnue_king() const;
   bool checking_permitted() const;
   bool drop_checks() const;
   bool must_capture() const;
@@ -501,6 +502,11 @@ inline PieceType Position::king_type() const {
   return var->kingType;
 }
 
+inline PieceType Position::nnue_king() const {
+  assert(var != nullptr);
+  return var->nnueKing;
+}
+
 inline bool Position::checking_permitted() const {
   assert(var != nullptr);
   return var->checking;
index 715bba8..04ff04c 100644 (file)
@@ -566,11 +566,11 @@ struct DirtyPiece {
   // Max 3 pieces can change in one move. A promotion with capture moves
   // both the pawn and the captured piece to SQ_NONE and the piece promoted
   // to from SQ_NONE to the capture square.
-  Piece piece[3];
+  Piece piece[12];
 
   // From and to squares, which may be SQ_NONE
-  Square from[3];
-  Square to[3];
+  Square from[12];
+  Square to[12];
 };
 
 /// Score enum stores a middlegame and an endgame value in a single integer (enum).
index 818db65..a68dac9 100644 (file)
@@ -131,6 +131,7 @@ struct Variant {
   // Derived properties
   bool fastAttacks = true;
   bool fastAttacks2 = true;
+  PieceType nnueKing = KING;
 
   void add_piece(PieceType pt, char c, char c2 = ' ') {
       pieceToChar[make_piece(WHITE, pt)] = toupper(c);
@@ -175,6 +176,9 @@ struct Variant {
                                 })
                     && !cambodianMoves
                     && !diagonalLines;
+      nnueKing =  pieceTypes.find(KING) != pieceTypes.end() ? KING
+                : extinctionPieceTypes.find(COMMONER) != extinctionPieceTypes.end() ? COMMONER
+                : NO_PIECE_TYPE;
       return this;
   }
 };