鉄分は大事。(特にヘム鉄)

こっち→ https://brookbach.com

RustでAIその2 続き

概要

前回

brookbach.hatenablog.com

の続き,だいぶ間があいちゃったのでとりあえず雑だけど投稿してみる

数手先読み

  • 数手先読みする
  • SEARCH_DEPTHを定義
  • 1.8.0ではRefCellにOrdが定義されていないのでRefNode(RefCell)として定義
  • RefNodeはタプルなので,borrow, borrow_mut,into_innerを呼び出されたときに中身をそのまま呼ぶ関数を定義
#[derive(Eq, PartialEq, Debug)]
struct RefNode(RefCell<Node>)
impl PartialOrd for RefNode {
    fn partial_cmp(&self, other: &RefNode) -> Option<Ordering> {
        self.0.borrow().partial_cmp(&*other.0.borrow())
    }

    fn lt(&self, other: &RefNode) -> bool {
        *self.0.borrow() < *other.0.borrow()
    }

    #[inline]
    fn le(&self, other: &RefNode) -> bool {
        *self.0.borrow() <= *other.0.borrow()
    }

    #[inline]
    fn gt(&self, other: &RefNode) -> bool {
        *self.0.borrow() > *other.0.borrow()
    }

    #[inline]
    fn ge(&self, other: &RefNode) -> bool {
        *self.0.borrow() >= *other.0.borrow()
    }
}

impl Ord for RefNode {
    #[inline]
    fn cmp(&self, other: &RefNode) -> Ordering {
        self.0.borrow().cmp(&*other.0.borrow())
    }
}

impl RefNode {
    fn borrow(&self) -> Ref<Node> {
        self.0.borrow()
    }

    fn borrow_mut(&self) -> RefMut<Node> {
        self.0.borrow_mut()
    }

    fn into_inner(self) -> Node {
        self.0.into_inner()
    }
}

定義できたら,元記事を参考に幅優先探索を実装

探索後,一番スコアが高いノードの親ノードまでトラバース この時点の親ノードは,Rc(RefNode)となっているので,Rc::try_unwrapでRcをはがしてからRefNodeの中身をinto_innerで取り出す

search関数はこんな感じ

    fn search(&mut self) -> Node {
        let mut heap = BinaryHeap::new();

        heap.push(self.head.clone().unwrap());
        for _ in 0..SEARCH_DEPTH {
            let mut tmp = heap.clone();
            heap.clear();

            while let Some(current_node) = tmp.pop() {
                for i in 0..DX.len() {
                    let next_node = Node::new(Some(current_node.clone()));
                    next_node.borrow_mut().players[self.p].x1 += DX[i];
                    next_node.borrow_mut().players[self.p].y1 += DY[i];
                    next_node.borrow_mut().output = OUTPUT[i].to_string();
                    let score = next_node.borrow_mut().eval_player(self.p);
                    next_node.borrow_mut().score += score;
                    heap.push(next_node.clone());
                }
            }
        }

        // the top of heap is the best result after search
        let mut node = heap.pop().unwrap();
        // traverse until the next of head
        loop {
            let next_node;
            match node.borrow_mut().parent.take() {
                Some(parent) => {
                    next_node = parent;
                }
                None => { panic!(); }
            }
            if next_node == self.head.clone().unwrap() {
                break;
            }
            node = next_node;
        }
        heap.clear();
        Rc::try_unwrap(node).ok().unwrap().into_inner()
    }

ビームサーチ版

ヒープをビーム幅に縮小する以外は同じ

    fn search(&mut self) -> Node {
        let mut heap = BinaryHeap::new();

        heap.push(self.head.clone().unwrap());
        for _ in 0..SEARCH_DEPTH {
            let mut tmp = BinaryHeap::new();
            for _ in 0..BEAM_WIDTH {
                if let Some(data) = heap.pop() {
                    tmp.push(data);
                } else {
                    break;
                }
            }
            heap.clear();

            while let Some(current_node) = tmp.pop() {
                for i in 0..DX.len() {
                    let next_node = Node::new(Some(current_node.clone()));
                    next_node.borrow_mut().players[self.p].x1 += DX[i];
                    next_node.borrow_mut().players[self.p].y1 += DY[i];
                    next_node.borrow_mut().output = OUTPUT[i].to_string();
                    let score = next_node.borrow_mut().eval_player(self.p);
                    next_node.borrow_mut().score += score;
                    heap.push(next_node.clone());
                }
            }
        }

        // the top of heap is the best result after search
        let mut node = heap.pop().unwrap();
        // traverse until the next of head
        loop {
            let next_node;
            match node.borrow_mut().parent.take() {
                Some(parent) => {
                    next_node = parent;
                }
                None => { panic!(); }
            }
            if next_node == self.head.clone().unwrap() {
                break;
            }
            node = next_node;
        }
        heap.clear();
        Rc::try_unwrap(node).ok().unwrap().into_inner()
    }

最終版

use std::io;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::rc::Rc;
use std::cell::{Ref, RefMut, RefCell};

macro_rules! print_err {
    ($($arg:tt)*) => (
        {
            use std::io::Write;
            writeln!(&mut ::std::io::stderr(), $($arg)*).ok();
        }
    )
}

macro_rules! parse_input {
    ($x:expr, $t:ident) => ($x.trim().parse::<$t>().unwrap());
}

const DX: [i32; 4] = [1, 0, -1, 0];
const DY: [i32; 4] = [0, 1, 0, -1];
const OUTPUT: [&'static str; 4] = ["RIGHT", "DOWN", "LEFT", "UP"];

// const MAX_PLAYER_NUM: i32 = 4;
const COL: i32 = 30;
const ROW: i32 = 20;

const SEARCH_DEPTH: usize = 100;
const BEAM_WIDTH: usize = 20;

#[derive(Eq, PartialEq, Debug)]
struct RefNode(RefCell<Node>);

type Link = Option<Rc<RefNode>>;

struct Game {
    head: Link,
    n: usize,
    p: usize,
}

#[derive(Eq, PartialEq, Debug, Clone)]
struct Node {
    score: i32,
    output: String,
    parent: Link,
    players: Vec<Player>,
}

#[derive(Eq, PartialEq, Debug, Clone)]
struct Player {
    x0: i32,
    y0: i32,
    x1: i32,
    y1: i32,
    locked_field: [[bool; ROW as usize]; COL as usize],
}

impl Ord for Node {
    fn cmp(&self, other: &Self) -> Ordering {
        self.score.cmp(&other.score)
    }
}

impl PartialOrd for Node {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.score.cmp(&other.score))
    }
}

impl PartialOrd for RefNode {
    fn partial_cmp(&self, other: &RefNode) -> Option<Ordering> {
        self.0.borrow().partial_cmp(&*other.0.borrow())
    }

    fn lt(&self, other: &RefNode) -> bool {
        *self.0.borrow() < *other.0.borrow()
    }

    #[inline]
    fn le(&self, other: &RefNode) -> bool {
        *self.0.borrow() <= *other.0.borrow()
    }

    #[inline]
    fn gt(&self, other: &RefNode) -> bool {
        *self.0.borrow() > *other.0.borrow()
    }

    #[inline]
    fn ge(&self, other: &RefNode) -> bool {
        *self.0.borrow() >= *other.0.borrow()
    }
}

impl Ord for RefNode {
    #[inline]
    fn cmp(&self, other: &RefNode) -> Ordering {
        self.0.borrow().cmp(&*other.0.borrow())
    }
}

impl RefNode {
    fn borrow(&self) -> Ref<Node> {
        self.0.borrow()
    }

    fn borrow_mut(&self) -> RefMut<Node> {
        self.0.borrow_mut()
    }

    fn into_inner(self) -> Node {
        self.0.into_inner()
    }
}

impl Game {
    fn new() -> Self {
        Game {
            head: None,
            n: 0,
            p: 0,
        }
    }

    fn input(&mut self) {
        let new_node = Node::new(self.head.clone());

        // input game data
        let mut input_line = String::new();
        io::stdin().read_line(&mut input_line).unwrap();
        if self.head.is_none() {
            // initialize game data & create new players
            let inputs = input_line.split(" ").collect::<Vec<_>>();
            self.n = parse_input!(inputs[0], usize);
            self.p = parse_input!(inputs[1], usize);
            new_node.borrow_mut().players = vec![Player::new(); self.n as usize];
        }

        // input player data
        for i in 0..self.n as usize {
            let mut input_line = String::new();
            io::stdin().read_line(&mut input_line).unwrap();
            let inputs = input_line.split(" ").collect::<Vec<_>>();

            let ref mut player = new_node.borrow_mut().players[i];
            player.x0 = parse_input!(inputs[0], i32);
            player.y0 = parse_input!(inputs[1], i32);
            player.x1 = parse_input!(inputs[2], i32);
            player.y1 = parse_input!(inputs[3], i32);
        }
        for i in 0..self.n as usize {
            let x0 = new_node.borrow().players[i].x0;
            let y0 = new_node.borrow().players[i].y0;
            let x1 = new_node.borrow().players[i].x1;
            let y1 = new_node.borrow().players[i].y1;
            for j in 0..self.n as usize {
                new_node.borrow_mut().players[j].locked_field[x0 as usize][y0 as usize] = true;
                new_node.borrow_mut().players[j].locked_field[x1 as usize][y1 as usize] = true;
            }
        }

        self.head = Some(new_node);
    }

    fn search(&mut self) -> Node {
        let mut heap = BinaryHeap::new();

        heap.push(self.head.clone().unwrap());
        for _ in 0..SEARCH_DEPTH {
            let mut tmp = BinaryHeap::new();
            for _ in 0..BEAM_WIDTH {
                if let Some(data) = heap.pop() {
                    tmp.push(data);
                } else {
                    break;
                }
            }
            heap.clear();

            while let Some(current_node) = tmp.pop() {
                for i in 0..DX.len() {
                    let next_node = Node::new(Some(current_node.clone()));
                    next_node.borrow_mut().players[self.p].x1 += DX[i];
                    next_node.borrow_mut().players[self.p].y1 += DY[i];
                    next_node.borrow_mut().output = OUTPUT[i].to_string();
                    let score = next_node.borrow_mut().eval_player(self.p);
                    next_node.borrow_mut().score += score;
                    heap.push(next_node.clone());
                }
            }
        }

        // the top of heap is the best result after search
        let mut node = heap.pop().unwrap();
        // traverse until the next of head
        loop {
            let next_node;
            match node.borrow_mut().parent.take() {
                Some(parent) => {
                    next_node = parent;
                }
                None => { panic!(); }
            }
            if next_node == self.head.clone().unwrap() {
                break;
            }
            node = next_node;
        }
        heap.clear();
        Rc::try_unwrap(node).ok().unwrap().into_inner()
    }
}

impl Node {
    fn new(parent: Option<Rc<RefNode>>) -> Rc<RefNode> {
        let new_node = if let Some(parent) = parent {
            // create child
            Node {
                parent: Some(parent.clone()),
                score: parent.borrow().score,
                players: parent.borrow().players.clone(),
                output: String::new(),
            }
        } else {
            // create the first
            Node {
                parent: None,
                score: 0,
                players: vec![],
                output: String::new(),
            }
        };

        Rc::new(RefNode(RefCell::new(new_node)))
    }

    fn eval_player(&mut self, p: usize) -> i32 {
        let x = self.players[p].x1;
        let y = self.players[p].y1;
        if self.players[p].can_move(x, y) && self.score >= 0 {
            self.players[p].locked_field[x as usize][y as usize] = true;
            0
        } else {
            -1
        }
    }
}

impl Player {
    fn new() -> Self {
        Player {
            x0: -1, y0: -1, x1: -1, y1: -1,
            locked_field: [[false; ROW as usize]; COL as usize]
        }
    }

    fn can_move(&self, x: i32, y: i32) -> bool {
        if x >= 0 && x < COL && y >= 0 && y < ROW && !self.locked_field[x as usize][y as usize] {
            true
        } else {
            false
        }
    }
}


fn main() {
    let mut game = Game::new();

    loop {
        game.input();
        let ans = game.search();
        println!("{}", ans.output);
    }
}

感想

コンパイラがきちんと指摘してくれるの,それはそうなんだけど,アルゴリズムの部分とかが間違ってると結局正常には動かなくて逆にうるさく指摘される分そちらに注力してしまってアルゴリズムに集中できないような気がする.ただコンパイラのチェックを通るとその部分に対しては変なバグが発生することがなくて安心.

まぁ慣れればそこで間違えることはなくなるんだろう(実際今回も書いてくうちにどうすればいいかなんとなくわかってミスは減った(気がする)). けどやっぱ学習曲線が険しいなぁと.

rubyのpryとか,Haskellのghciみたいに,型とかライフタイムとか借用周りとか今どうなってるのか見れるREPLが欲しい.

たぶんもっと慣れてる人は,そのへんをほとんど無意識に適用できるようになってる気がするので,そういう人達の知見をグラフィカルな形で初心者にもわかりやすく説明してくれるツールがほしいんじゃ

ライフタイムとか,所有権がどこに行くかとかどこにあるかとかアニメーションで表示したりできないのかなと.

まとめると,Rust書いてて面白いし安全に極振りしてるの重要視されないかもだけど実は大事なのでもっとメジャーになるといいなと思いました.