From 4ea84520ebe8e2e86dab456d6519742560bc243d Mon Sep 17 00:00:00 2001 From: Fabian Fichter Date: Fri, 26 Feb 2021 18:33:40 +0100 Subject: [PATCH] Support atomic NNUE --- src/nnue/features/half_kp.cpp | 8 ++++---- src/nnue/nnue_feature_transformer.h | 2 +- src/position.cpp | 8 ++++++++ src/position.h | 6 ++++++ src/types.h | 6 +++--- src/variant.h | 4 ++++ 6 files changed, 26 insertions(+), 8 deletions(-) diff --git a/src/nnue/features/half_kp.cpp b/src/nnue/features/half_kp.cpp index 4b90f47..508b024 100644 --- a/src/nnue/features/half_kp.cpp +++ b/src/nnue/features/half_kp.cpp @@ -44,8 +44,8 @@ namespace Eval::NNUE::Features { void HalfKP::AppendActiveIndices( const Position& pos, Color perspective, IndexList* active) { - Square ksq = orient(pos, perspective, pos.square(perspective)); - Bitboard bb = pos.pieces() & ~pos.pieces(KING); + Square ksq = orient(pos, perspective, pos.square(perspective, pos.nnue_king())); + Bitboard bb = pos.pieces() & ~pos.pieces(pos.nnue_king()); while (bb) { Square s = pop_lsb(&bb); active->push_back(make_index(pos, perspective, s, pos.piece_on(s), ksq)); @@ -58,10 +58,10 @@ namespace Eval::NNUE::Features { const Position& pos, const DirtyPiece& dp, Color perspective, IndexList* removed, IndexList* added) { - Square ksq = orient(pos, perspective, pos.square(perspective)); + Square ksq = orient(pos, perspective, pos.square(perspective, pos.nnue_king())); for (int i = 0; i < dp.dirty_num; ++i) { Piece pc = dp.piece[i]; - if (type_of(pc) == KING) continue; + if (type_of(pc) == pos.nnue_king()) continue; if (dp.from[i] != SQ_NONE) removed->push_back(make_index(pos, perspective, dp.from[i], pc, ksq)); if (dp.to[i] != SQ_NONE) diff --git a/src/nnue/nnue_feature_transformer.h b/src/nnue/nnue_feature_transformer.h index 2641321..38395c0 100644 --- a/src/nnue/nnue_feature_transformer.h +++ b/src/nnue/nnue_feature_transformer.h @@ -256,7 +256,7 @@ namespace Eval::NNUE { static_assert(std::is_same_v>, "Current code assumes that only kFriendlyKingMoved refresh trigger is being used."); - if ( dp.piece[0] == make_piece(c, KING) + if ( dp.piece[0] == make_piece(c, pos.nnue_king()) || (gain -= dp.dirty_num + 1) < 0) break; next = st; diff --git a/src/position.cpp b/src/position.cpp index f01aec4..065f34e 100644 --- a/src/position.cpp +++ b/src/position.cpp @@ -1624,6 +1624,14 @@ void Position::do_move(Move m, StateInfo& newSt, bool givesCheck) { if (type_of(bpc) != PAWN) st->nonPawnMaterial[bc] -= PieceValue[MG][bpc]; + if (Eval::useNNUE) + { + dp.piece[dp.dirty_num] = bpc; + dp.from[dp.dirty_num] = bsq; + dp.to[dp.dirty_num] = SQ_NONE; + dp.dirty_num++; + } + // Update board and piece lists // In order to not have to store the values of both board and unpromotedBoard, // demote promoted pieces, but keep promoted pawns as promoted, diff --git a/src/position.h b/src/position.h index 5ac26ef..e439ad7 100644 --- a/src/position.h +++ b/src/position.h @@ -141,6 +141,7 @@ public: PieceType castling_king_piece() const; PieceType castling_rook_piece() const; PieceType king_type() const; + PieceType nnue_king() const; bool checking_permitted() const; bool drop_checks() const; bool must_capture() const; @@ -501,6 +502,11 @@ inline PieceType Position::king_type() const { return var->kingType; } +inline PieceType Position::nnue_king() const { + assert(var != nullptr); + return var->nnueKing; +} + inline bool Position::checking_permitted() const { assert(var != nullptr); return var->checking; diff --git a/src/types.h b/src/types.h index 715bba8..04ff04c 100644 --- a/src/types.h +++ b/src/types.h @@ -566,11 +566,11 @@ struct DirtyPiece { // Max 3 pieces can change in one move. A promotion with capture moves // both the pawn and the captured piece to SQ_NONE and the piece promoted // to from SQ_NONE to the capture square. - Piece piece[3]; + Piece piece[12]; // From and to squares, which may be SQ_NONE - Square from[3]; - Square to[3]; + Square from[12]; + Square to[12]; }; /// Score enum stores a middlegame and an endgame value in a single integer (enum). diff --git a/src/variant.h b/src/variant.h index 818db65..a68dac9 100644 --- a/src/variant.h +++ b/src/variant.h @@ -131,6 +131,7 @@ struct Variant { // Derived properties bool fastAttacks = true; bool fastAttacks2 = true; + PieceType nnueKing = KING; void add_piece(PieceType pt, char c, char c2 = ' ') { pieceToChar[make_piece(WHITE, pt)] = toupper(c); @@ -175,6 +176,9 @@ struct Variant { }) && !cambodianMoves && !diagonalLines; + nnueKing = pieceTypes.find(KING) != pieceTypes.end() ? KING + : extinctionPieceTypes.find(COMMONER) != extinctionPieceTypes.end() ? COMMONER + : NO_PIECE_TYPE; return this; } }; -- 1.7.0.4