RustでLinkedListを実装してみた

TL;DR

最近、データ構造とアルゴリズムをRustで始めてみようかなと思っていて、双方向連結リスト(DLList)を実装していたのですが、難しかったのでメモ。基本的には公式のLinkedListを見ながら実装しました。

要件

method名挙動
get(i)i番目の要素を見る
set(i, x)i番目の要素をxにする
add(i, x)i番目にxを加える
remove(i)i番目の要素を削除する

の4つのメソッドを実装します。

実装

構造体の定義

公式はNonNullを使っていたので真似しました。unsafe使いまくる必要があるので怖いんですけど、公式がこうしていたのでunsafeを使うことにしました。RefCell, Rc, Weakを使うのも良いというのを何件か見かけました。unsafe使わない実装もやってみたいですね。あとPhantomDataなにってなってるので勉強します。

use std::{marker::PhantomData, ptr::NonNull};

#[derive(Debug, Clone)]
pub struct Node<T> {
    x: T,
    next: Option<NonNull<Node<T>>>,
    prev: Option<NonNull<Node<T>>>,
}

#[derive(Debug)]
pub struct DLList<T> {
    head: Option<NonNull<Node<T>>>,
    tail: Option<NonNull<Node<T>>>,
    n: usize,
    marker: PhantomData<Box<Node<T>>>,
}

getとsetの定義

最初に実装したいのはi番目のNonNull<Node<T>>を持ってくる処理です。公式をよく見ると、Cursor的な構造体を定義しています。現在のindexとその時点のポインタを持っていて、nextprevを移動できる感じのやつです。公式だとCursorMutも定義していてそっちのほうがRustっぽい気もしたんですが、なくても今回の要件的には問題なさそうだったので省略しています。

pub struct Cursor<'a, T: 'a> {
    index: usize,
    current: Option<NonNull<Node<T>>>,
    list: &'a DLList<T>,
}

impl<'a, T> Cursor<'a, T> {
    pub fn move_next(&mut self) {
        match self.current.take() {
            Some(cur) => unsafe {
                self.current = cur.as_ref().next;
                self.index += 1;
            },
            None => {
                self.current = self.list.head;
                self.index = 0;
            }
        }
    }

    pub fn move_prev(&mut self) {
        match self.current.take() {
            Some(cur) => unsafe {
                self.current = cur.as_ref().prev;
                self.index -= 1;
            },
            None => {
                self.current = self.list.tail;
                self.index = self.index.checked_sub(1).unwrap_or_else(|| self.list.n);
            }
        }
    }
}

Cursorを使ってi番目のノードのポインタを持ってくる処理を書きます。headとtailの近いほうから見に行きます。getsetはこのポインタから値を取り出すか変更する処理なんですぐ書けます。

impl<T> DLList<T> {
    fn get_node(&self, index: usize) -> Option<NonNull<Node<T>>> {
        if index < self.n / 2 {
            let mut cursor = Cursor {
                index: 0,
                current: self.head,
                list: &self,
            };

            for _ in 0..index {
                cursor.move_next();
            }

            cursor.current
        } else {
            let mut cursor = Cursor {
                index: self.n,
                current: self.tail,
                list: &self,
            };

            for _ in index..self.n - 1 {
                cursor.move_prev();
            }

            cursor.current
        }
    }

    pub fn get(&self, index: usize) -> Option<&T> {
        unsafe { self.get_node(index).map(|n| &(*n.as_ptr()).x) }
    }

    pub fn set(&mut self, x: T, index: usize) {
        unsafe {
            self.get_node(index).map(|mut n| n.as_mut().x = x);
        }
    }
}

addの定義

addを行うためには、入れたいノード(e)の前に新しいノードを入れればいいです。つまり

e.prev -> e -> e.next
e.prev -> added -> e -> e.next

のような感じです。この処理の公式の実装はsplice_nodesというメソッドで実装されているようです。この実装はもうちょっと複雑で、他のリストがあったとして、そのheadとtailのポインタを使うことで

e.prev -> added_1 -> added_2 -> ... -> added_n -> e -> e.next

のような処理ができます。ただ、今回は一個入れればいいので、もう少し簡略化して書きます。引数にe.preveaddedをとって挿入するsplice_nodeを実装しました。addに関しては入れたいところのポインタをget_nodeで持ってきて、splice_nodeすればいいです。

impl<T> DLList<T> {
    fn splice_node(
        &mut self,
        existing_prev: Option<NonNull<Node<T>>>,
        existing_next: Option<NonNull<Node<T>>>,
        mut splice_node: NonNull<Node<T>>,
    ) {
        if let Some(mut existing_prev) = existing_prev {
            unsafe {
                existing_prev.as_mut().next = Some(splice_node);
            }
        } else {
            self.head = Some(splice_node)
        }

        if let Some(mut existing_next) = existing_next {
            unsafe {
                existing_next.as_mut().prev = Some(splice_node);
            }
        } else {
            self.tail = Some(splice_node)
        }

        unsafe {
            splice_node.as_mut().prev = existing_prev;
            splice_node.as_mut().next = existing_next;
        }

        self.n += 1;
    }

    pub fn add(&mut self, x: T, index: usize) {
        let current = self.get_node(index);
        unsafe {
            let spliced_node = Box::leak(Box::new(Node::new(x))).into();
            let node_prev = match current {
                None => self.tail,
                Some(node) => node.as_ref().prev,
            };
            self.splice_node(node_prev, current, spliced_node);
        }
    }
}

removeの実装

消したいノードをn_iとすると、以下のようにすればいいことがわかります。

n_i.prev -> n_i -> n_i.next
n_i.prev -> n_i.next

この処理はunlink_nodeという関数で実装されています。やってる処理は上記そのままです。removeaddのときと同じく、i番目のノードのポインタを持ってきて、unlink_nodeに渡すだけですね。

impl<T> DLList<T> {
    fn unlink_node(&mut self, mut node: NonNull<Node<T>>) {
        let node = unsafe { node.as_mut() };

        match node.prev {
            Some(prev) => unsafe { (*prev.as_ptr()).next = node.next },
            None => self.head = node.next,
        };

        match node.next {
            Some(next) => unsafe { (*next.as_ptr()).prev = node.prev },
            None => self.tail = node.prev,
        }

        self.n -= 1;
    }

    pub fn remove(&mut self, index: usize) -> Option<&T> {
        let node = self.get_node(index);

        if let Some(node) = node {
            self.unlink_node(node);
        }

        unsafe { node.map(|n| &(*n.as_ptr()).x) }
    }
}

感想

公式のコードはやっぱり勉強になります。

この記事に関するIssueをGithubで作成する

Read Next