配列をそのままフロントエンド(HTML+Javascript)に転送する。
// 返却用の構造体のために必要 use serde::Serialize; // 値を返す構造体 #[derive(Serialize)] struct ImageData{ image:Vec<u8>, width:u32, height:u32, } // 画像と情報を返す関数 #[tauri::command] fn get_image() -> ImageData { let width:u32 = 256; let height:u32 = 400; // 画像生成 let mut imgrgb = Vec::<u8>::new(); for x in 0..width{ let r = (x % height) as u8; for _ in 0..height{ imgrgb.push(((r as usize ) % 256) as u8); imgrgb.push(0); imgrgb.push(0); } } // 返却 ImageData { image: imgrgb, width: width, height: height, } } fn main() { tauri::Builder::default() .invoke_handler(tauri::generate_handler![get_image]) // get_image関数を登録 .run(tauri::generate_context!()) .expect("error while running Tauri application"); }
<!DOCTYPE html> <html lang="ja"> <head> <meta charset="UTF-8"> <title>Rust+Tauri画像表示</title> <style> canvas#mycanvas { border: 1px solid #000; } </style> <script type="module" src="/main.js" defer></script> </head> <body> <h1>Rust+Tauri画像表示</h1> <button id="ButtonLoadImage">画像表示</button> <canvas id="mycanvas" width="640" height="480"></canvas> </body> </html>
const { invoke } = window.__TAURI__.core; document.getElementById('ButtonLoadImage').addEventListener('click',async()=>{ const mycanvas = document.getElementById('mycanvas'); const ctx = mycanvas.getContext('2d'); const ImageData = await window.__TAURI__.core.invoke("get_image"); const rustArray = new Uint8Array(ImageData.image); // RGB8の配列 const width = ImageData.width; const height = ImageData.height; // Canvasサイズを調整 mycanvas.width = width; mycanvas.height = height; // キャンバスの画像データを作成 const CanvasImageData = ctx.createImageData(width,height); // キャンバスに与えるデータへの参照を取得 const tmp = CanvasImageData.data; // Rust側がRGBを返すので、RGBA形式で格納 for(let i=0;i<width*height;i++){ tmp[i*4+0] = rustArray[i*3+0]; tmp[i*4+1] = rustArray[i*3+1]; tmp[i*4+2] = rustArray[i*3+2]; tmp[i*4+3] = 255; } ctx.putImageData(CanvasImageData,0,0); });
上ではデータをrgbで送ったため、Canvasが対応するrgbaに変換するためにforを回す必要があったが、rgbaであれば直接コピーできる。
const { invoke } = window.__TAURI__.core; document.getElementById('ButtonLoadImage').addEventListener('click',async()=>{ const mycanvas = document.getElementById('mycanvas'); const ctx = mycanvas.getContext('2d'); const ImageData = await window.__TAURI__.core.invoke("get_image"/*,{index}*/ ); const rustArray = new Uint8Array(ImageData.image); // RGB8の配列 const width = ImageData.width; const height = ImageData.height; // Canvasサイズを調整 mycanvas.width = width; mycanvas.height = height; // キャンバスの画像データを作成 const CanvasImageData = ctx.createImageData(width,height); // キャンバスにRustから受け取ったデータをセット CanvasImageData.data.set(rustArray); ctx.putImageData(CanvasImageData,0,0); });
画像を高速に表示しようとするとWebGLを使うのがいいらしいが、いろいろと手間なので、まずはRust側でbase64に変換してimgに設定する方法を試す。高いフレームレートが要求される用途でなければ、十分に機能する。
use base64::{engine::general_purpose, Engine as _}; use std::io::Cursor; use image::codecs::png::PngEncoder; use image::ImageEncoder; // 返却用の構造体のために必要 use serde::Serialize; // 値を返す構造体 #[derive(Serialize)] struct ImageData{ image:String, width:u32, height:u32, } // 画像と情報を返す関数 #[tauri::command] fn get_image(index: usize) -> ImageData { let width:u32 = 256; let height:u32 = 400; // 画像生成 let mut imgrgb = Vec::<u8>::new(); for x in 0..width{ let r = (x % height) as u8; for y in 0..height{ imgrgb.push(((r as usize * index) % 256) as u8); imgrgb.push(0); imgrgb.push(0); } } // PNGデータにエンコード // RGB配列 → PNG形式への変換 let mut buffer = Cursor::new(Vec::new()); let encoder = PngEncoder::new(&mut buffer); encoder.write_image( &imgrgb, width, height, image::ColorType::Rgb8 ).unwrap(); // Base64エンコードして返す // PNG形式 → Base64表現への変換 let base64_image = general_purpose::STANDARD.encode(buffer.get_ref()); // フロントエンドでimgタグに設定できる形式にする let img =format!("data:image/png;base64,{}", base64_image); // 返却 ImageData { image: img, width: width, height: height, } } fn main() { tauri::Builder::default() .invoke_handler(tauri::generate_handler![get_image]) // get_image関数をフロントエンドが呼び出せるようにする .run(tauri::generate_context!()) .expect("error while running Tauri application"); }
[package]
name = "my1st_tauri"
version = "0.1.0"
description = "A Tauri App"
authors = ["you"]
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[build-dependencies]
tauri-build = { version = "2", features = [] }
[dependencies]
tauri = { version = "2", features = [] }
tauri-plugin-shell = "2"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
image = "0.24"
base64 = "0.21"
<!DOCTYPE html> <html lang="ja"> <head> <meta charset="UTF-8"> <title>Rust+Tauri画像表示</title> </head> <body> <h1>Rust+Tauri画像表示</h1> <img id="image-viewer" alt="Image Viewer" style="border: 1px solid black; width: 256px; height: 256px;"> <p>左右キーで再読み込み</p> <script> let currentIndex = 0;
async function updateImage(index) { // Rustから画像データを取得 const imageData = await window.__TAURI__.core.invoke("get_image", { index }); imageview = document.getElementById("image-viewer"); // 表示領域の取得 imageview.src = imageData.image; // 画像を設定 imageview.style.width = `${imageData.width}px`; // 画像の幅を設定 imageview.style.height = `${imageData.height}px`; // 画像の高さを設定 }
// 初期画像を読み込み updateImage(currentIndex);
// キーイベントリスナー document.addEventListener("keydown", (event) => { if (event.key === "ArrowRight") { currentIndex++; updateImage(currentIndex); } else if (event.key === "ArrowLeft") { currentIndex = Math.max(0, currentIndex - 1); updateImage(currentIndex); } });
</script> </body> </html>
久しぶりにTauriを学ぼうと思い、以前の記事を参考にプロジェクトを作ってみた
すると以下のエラーが発生
これはcreate-tauri-appが作成するCargo.tomlが現在インストールされているTauriライブラリのバージョンに対応していないことが由来らしい。
ここで、create-tauri-appのバージョンを調べてみると
npm list -g create-tauri-app
C:\application\node.js\node-v21.6.2-win-x64
`-- (empty)
となっていて、create-tauri-appのバージョンがemptyになっている。このとき、create-tauri-appは必要な時だけ最新版がダウンロードされて使用されるらしい。つまりcreate-tauri-appは最新版が使われるということ。ここで、cargo install tauri-cliを実行してみると
error: failed to compile `tauri-cli v2.1.0`, intermediate artifacts can be found at `C:\User\myuser\AppData\Local\Temp\cargo-installSnryuD`.
To reuse those artifacts with a future compilation, set the environment variable `CARGO_TARGET_DIR` to that path.
Caused by:
package `tauri-cli v2.1.0` cannot be built because it requires rustc 1.77.2 or newer, while the currently active rustc version is 1.74.1
Try re-running cargo install with `--locked`
と、「現在のrustcのバージョンが1.74.1だが、最新のtauri-cliのバージョンはrustc 1.77.2以上を要求する」という理由でtauri-cliのインストールが失敗する。
つまり、プロジェクトを作るcreate-tauri-appのバージョンが常に最新版が使われているが、tauriライブラリの管理などを行うtauri-cliがructc 1.74対応版という古いものなので、最新のプロジェクトをビルドできないということだと思う。
create-tauri-appのバージョンを1.74版に指定すればうまくいきそうだが、低いバージョンを使う意味も今のところないので、rustcのバージョンアップをして対応する。
なおバージョンチェックは
rustc --version
で行う。
これでnpx create-tauri-appでプロジェクトを作成できるようになった。
なお、最新版ではIdentifierを作成時に指定できるので、手動でbundle > identifierを変更する必要もない。
glTranslated等のレガシーな関数が非推奨になっているので移動行列や回転行列や射影行列を自分で計算しないといけない。面倒なのでクレートを使う。いくつかあるが、cgmathがglmに近い使い勝手らしい。
以下のようにCargo.tomlにcgmathを追加する。
[package]
name = "opengltest"
version = "0.1.0"
edition = "2021"
[dependencies]
glutin = "0.26.0"
gl="0.14.0"
cgmath = "0.18.0"
以下、translate,rotate,perspective view matrixの使用例。
pub fn draw_triangle(programid:gl::types::GLuint,vertexbuffer: gl::types::GLuint,colorbuffer:gl::types::GLuint){ unsafe { gl::UseProgram(programid); } let translate_mat = cgmath::Matrix4::from_translation( cgmath::Vector3::new(0.0,0.0,-3.0) ); let rotate_mat = cgmath::Matrix4::from_angle_z(cgmath::Deg(45.0)); let model_mat = rotate_mat * translate_mat; let fov = cgmath::Deg(45.0); let aspect = 1.0; let near = 0.1; let far = 100.0; let proj_mat = cgmath::perspective(fov,aspect,near,far); let proj; let model; unsafe { let proj_location_name = CString::new("projectionMatrix").unwrap(); let model_location_name = CString::new("modelViewMatrix").unwrap(); proj = gl::GetUniformLocation(programid,proj_location_name.as_ptr()); model = gl::GetUniformLocation(programid,model_location_name.as_ptr()); gl::UniformMatrix4fv(proj,1,gl::FALSE,proj_mat.as_ptr()); gl::UniformMatrix4fv(model,1,gl::FALSE,model_mat.as_ptr()); } /* // デバッグ println!("proj location: {}",proj); println!("model location: {}",model); */ unsafe{ gl::EnableVertexAttribArray(0); gl::BindBuffer(gl::ARRAY_BUFFER,vertexbuffer); gl::VertexAttribPointer( 0, 3, gl::FLOAT, gl::FALSE, 0, null() ); } unsafe{ gl::EnableVertexAttribArray(1); gl::BindBuffer(gl::ARRAY_BUFFER,colorbuffer); gl::VertexAttribPointer( 1, 3, gl::FLOAT, gl::FALSE, 0, null() ); } unsafe { gl::DrawArrays(gl::TRIANGLES, 0,3); let err_check = gl::GetError(); if err_check != gl::NO_ERROR { println!("ERROR::: {}\n", err_check); } gl::DisableVertexAttribArray(0); gl::DisableVertexAttribArray(1); } unsafe { gl::UseProgram(0); } }
以前初期化はやったことがあるが、今回はちゃんと三角形を描く。
基本的にglBeginのような古い関数は使えないので、GLSLで書くことになる。
Rust内の文字列をC言語の文字列に変換するところがいちいちややこしい。
use std::ptr::null; use std::ffi::CString; use std::ffi::CStr;
pub fn draw_prepare() ->(gl::types::GLuint, gl::types::GLuint){ // 頂点の定義 let vertices: [f32;9]=[ 0.0, 0.5, 0.0, // 上頂点 -0.5, -0.5, 0.0, // 左下 0.5, -0.5, 0.0, // 右下 ]; let color:[f32;9]=[ 1.0,0.0,0.0, 0.0,1.0,0.0, 0.0,0.0,1.0, ]; let mut vertexbuffer=0; unsafe{ gl::GenBuffers(1,&mut vertexbuffer); gl::BindBuffer(gl::ARRAY_BUFFER,vertexbuffer); gl::BufferData( gl::ARRAY_BUFFER, 3*3*std::mem::size_of::<gl::types::GLfloat>() as gl::types::GLsizeiptr, vertices.as_ptr() as *const _, gl::STATIC_DRAW ); } let mut colorbuffer=0; unsafe{ gl::GenBuffers(1,&mut colorbuffer); gl::BindBuffer(gl::ARRAY_BUFFER,colorbuffer); gl::BufferData( gl::ARRAY_BUFFER, 3*3*std::mem::size_of::<gl::types::GLfloat>() as gl::types::GLsizeiptr, color.as_ptr() as *const _, gl::STATIC_DRAW ); } (vertexbuffer,colorbuffer) }
pub fn prepare_vertex_shader()->gl::types::GLuint{ let mut VertexShaderID = 0; // 頂点シェーダプログラム let vertex_shader_source = "\ #version 460 core layout (location = 0) in vec3 aPos; layout (location = 1) in vec3 incolor; out vec4 vertexColor; uniform mat4 modelViewMatrix; uniform mat4 projectionMatrix; void main() { gl_Position = projectionMatrix * modelViewMatrix * vec4(aPos, 1.0); vertexColor = vec4(incolor, 1.0); } "; let c_src = CString::new(vertex_shader_source).unwrap(); let mut Result:gl::types::GLint = 0; let mut InfoLogLength:i32 = 0; let mut info_log; unsafe{ VertexShaderID = gl::CreateShader(gl::VERTEX_SHADER); gl::ShaderSource( VertexShaderID, 1, &c_src.as_ptr(), std::ptr::null() ); gl::CompileShader(VertexShaderID); // シェーダのチェック gl::GetShaderiv(VertexShaderID,gl::COMPILE_STATUS,&mut Result); gl::GetShaderiv(VertexShaderID,gl::INFO_LOG_LENGTH,&mut InfoLogLength); if InfoLogLength > 0 { info_log = vec![0u8; InfoLogLength as usize]; if Result == gl::FALSE as i32 { gl::GetShaderInfoLog( VertexShaderID, InfoLogLength, std::ptr::null_mut(), info_log.as_mut_ptr() as *mut gl::types::GLchar ); if let Ok(msg) = CStr::from_ptr(info_log.as_ptr() as *const i8).to_str() { println!("Vertex Shader Error :\n {}\n", msg); } } } } VertexShaderID }
pub fn prepare_fragment_shader()->gl::types::GLuint{ let mut FragmentShaderID=0; let fragment_shader_source = "\ #version 460 core out vec4 FragColor; in vec4 vertexColor; void main() { FragColor = vertexColor; } "; let c_str = CString::new(fragment_shader_source).unwrap(); let mut Result:gl::types::GLint=0; let mut InfoLogLength:i32=0; let mut info_log; unsafe { FragmentShaderID = gl::CreateShader(gl::FRAGMENT_SHADER); gl::ShaderSource( FragmentShaderID, 1, &c_str.as_ptr(), null() ); gl::CompileShader(FragmentShaderID); // フラグメントシェーダ gl::GetShaderiv(FragmentShaderID, gl::COMPILE_STATUS, &mut Result); gl::GetShaderiv(FragmentShaderID, gl::INFO_LOG_LENGTH, &mut InfoLogLength); if InfoLogLength > 0 { if Result == gl::FALSE as i32 { info_log = vec![0u8; InfoLogLength as usize]; gl::GetShaderInfoLog( FragmentShaderID, InfoLogLength, std::ptr::null_mut(), info_log.as_mut_ptr() as *mut gl::types::GLchar ); if let Ok(msg) = CStr::from_ptr(info_log.as_ptr() as *const i8).to_str() { println!("Fragment Shader Error :\n{}\n", msg); } } } } FragmentShaderID }
pub fn link_program(VertexShaderID:gl::types::GLuint,FragmentShaderID:gl::types::GLuint)->gl::types::GLuint{ let mut Result:gl::types::GLint = gl::FALSE as i32; let mut InfoLogLength:i32=0; let mut ProgramID:gl::types::GLuint=0; println!("Linking program"); unsafe{ ProgramID = gl::CreateProgram(); gl::AttachShader(ProgramID,VertexShaderID); gl::AttachShader(ProgramID,FragmentShaderID); gl::LinkProgram(ProgramID); gl::GetProgramiv(ProgramID,gl::LINK_STATUS,&mut Result); gl::GetProgramiv(ProgramID,gl::INFO_LOG_LENGTH,&mut InfoLogLength); if InfoLogLength > 0 { let mut ProgramErrorMessage = vec![0u8; InfoLogLength as usize]; gl::GetProgramInfoLog( ProgramID, InfoLogLength, std::ptr::null_mut(), ProgramErrorMessage.as_mut_ptr() as *mut gl::types::GLchar ); if let Ok(msg) = CStr::from_ptr(ProgramErrorMessage.as_ptr() as *const i8).to_str() { println!("Program Link Error:\n{}\n",msg); } } } ProgramID }
pub fn draw_triangle(programid:gl::types::GLuint,vertexbuffer: gl::types::GLuint,colorbuffer:gl::types::GLuint){ unsafe { gl::UseProgram(programid); } let proj_mat:[f32;16]=[ 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, ]; let model_mat:[f32;16]=[ 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, ]; let proj; let model; unsafe { let proj_location_name = CString::new("projectionMatrix").unwrap(); let model_location_name = CString::new("modelViewMatrix").unwrap(); proj = gl::GetUniformLocation(programid,proj_location_name.as_ptr()); model = gl::GetUniformLocation(programid,model_location_name.as_ptr()); gl::UniformMatrix4fv(proj,1,gl::FALSE,proj_mat.as_ptr()); gl::UniformMatrix4fv(model,1,gl::FALSE,model_mat.as_ptr()); } /* // デバッグ println!("proj location: {}",proj); println!("model location: {}",model); */ unsafe{ gl::EnableVertexAttribArray(0); gl::BindBuffer(gl::ARRAY_BUFFER,vertexbuffer); gl::VertexAttribPointer( 0, 3, gl::FLOAT, gl::FALSE, 0, null() ); } unsafe{ gl::EnableVertexAttribArray(1); gl::BindBuffer(gl::ARRAY_BUFFER,colorbuffer); gl::VertexAttribPointer( 1, 3, gl::FLOAT, gl::FALSE, 0, null() ); } unsafe { gl::DrawArrays(gl::TRIANGLES, 0,3); let err_check = gl::GetError(); if err_check != gl::NO_ERROR { println!("ERROR::: {}\n", err_check); } gl::DisableVertexAttribArray(0); gl::DisableVertexAttribArray(1); } unsafe { gl::UseProgram(0); } }
// 自作ファイル draw.rs の使用 mod draw; // 自作関数を呼び出せるようにする use draw::draw_triangle; use draw::draw_prepare; use draw::prepare_vertex_shader; use draw::prepare_fragment_shader; use draw::link_program; fn main() { let event_loop = glutin::event_loop::EventLoop::new(); let window = glutin::window::WindowBuilder::new().with_title("Rust glutin OpenGL"); // GLのコンテキストを作成 // この gl_context は ContextBuilder<NotCurrent,Window> 型 let gl_context = glutin::ContextBuilder::new() // コアプロファイルを (3, 1) で3.1 に指定。3.2以上だとVAOが必須になりコードの追記が必要なため .with_gl(glutin::GlRequest::Specific(glutin::Api::OpenGl, (3, 1))) .build_windowed(window, &event_loop) .expect("Cannot create context"); // Rustにはシャドーイングがあるので、同じ名前の別変数を同じスコープ内に定義できる。 // この gl_context は ContextBuilder<PossibleCurrent,Window> 型 // 以降、gl_currentは以前の型の意味で用いることはできない let gl_context = unsafe { gl_context .make_current() .expect("Failed to make context current") }; // OpenGLの各関数を初期化 // これでgl::Viewportをはじめとする各関数の関数ポインタにアドレスが設定され、呼び出せるようになる。 gl::load_with(|symbol| gl_context.get_proc_address(symbol) as *const _); /////////////////////////////////////////////////////////////////////////////////// let (vbuffer,cbuffer) = draw_prepare(); // 座標と色のバッファを定義 let VertexShaderID:gl::types::GLuint = prepare_vertex_shader(); // バーテクスシェーダ作成 let FragmentShaderID:gl::types::GLuint = prepare_fragment_shader(); // フラグメントシェーダ作成 let ProgramID:gl::types::GLuint = link_program(VertexShaderID,FragmentShaderID); // プログラムのリンク /////////////////////////////////////////////////////////////////////////////////// // event_loopを開始する event_loop.run(move |event, _, control_flow| { // Pollを指定するとループが走り続ける // Waitを指定するとイベントが発生したときだけループが動く。 *control_flow = glutin::event_loop::ControlFlow::Wait; match event { glutin::event::Event::WindowEvent { event, .. } => match event { //////////////////////////////////////////////////////////////////// // ウィンドウを閉じる glutin::event::WindowEvent::CloseRequested => { // ウィンドウが閉じられた場合、event_loopを停止する *control_flow = glutin::event_loop::ControlFlow::Exit; }, //////////////////////////////////////////////////////////////////// // ウィンドウのサイズを変更 glutin::event::WindowEvent::Resized(new_size) => { // ビューポート再設定 unsafe { gl::Viewport(0, 0, new_size.width as i32, new_size.height as i32); } }, //////////////////////////////////////////////////////////////////// _ => (), }, _ => (), } // 描画 unsafe { gl::Clear(gl::COLOR_BUFFER_BIT | gl::DEPTH_BUFFER_BIT); gl::ClearColor(0.3, 0.3, 0.5, 1.0); gl::Disable(gl::DEPTH_TEST); gl::Disable(gl::CULL_FACE); // 描画 draw_triangle(ProgramID,vbuffer,cbuffer); gl::Flush(); } // スワップバッファ gl_context.swap_buffers().unwrap(); }); }
Rust用の統合開発環境RustRoverが、非商用に限り無料で使用できるらしい。
インストール後、起動するとライセンスの種類を選択できるので、無料ライセンスを選ぶ。
新規プロジェクトを作る。
実行。特に何もしなくても、▷ボタンで実行できる。
デバッグ用のコンソールはUnicodeに対応しているらしく絵文字もそのまま表示できる。
昔RustでWin32APIを使ってみたが、それも何の工夫もせずにちゃんと実行できた。
個人的に極めて重要なのが、Ctrl+ホイールでエディタの文字サイズ変更がどうしても欲しい。あとテーマはLightがいい。
エディター→一般→インライン補完→ローカルのFull Line補完候補を有効にするのチェックを外す。
これが入っていると調べなくても書けてしまうので上達しない。
Pythonの文字列はutf8で扱われているが、C++側はUnicodeに弱い。使用しているライブラリがマルチバイト(非utf8なstd::string)しか受け付けてないような場合はPython側でshift_jisに変換して渡してやると都合がいい。
#include <pybind11/pybind11.h> #include <cstdio> namespace py = pybind11; // 名前空間を指定 void func_setstring(std::string text) { FILE* fp = fopen("test.txt", "wb"); if (fp == NULL) { return; } fwrite(text.c_str(), sizeof(char), text.size(), fp); fclose(fp); } // PYBIND11_MODULE(モジュール名, モジュール変数名) // 生成物はmy_module_string.pydという名前にしなければならない PYBIND11_MODULE(my_module_string, m) { // 関数を定義 m.def("func_string", &func_setstring, "Write text to file", py::arg("text")); }
import my_module_string text = "こんにちは、世界!" # テキストを書き込む my_module_string.func_string(text.encode("shift_jis"))
エディタで確認すると、SJISと判別されている。
そのまま渡すとutf-8で受け取る。C++はstd::stringでutf8を扱いたがるので紛らわしい。高度な文字列処理が必要ないならそのまま使用できる。ファイルに保存してやれば受け取った内容がutf-8であることをテキストエディタで確認できる。
import my_module_string text = "こんにちは、世界!" # テキストを書き込む my_module_string.func_string(text)
class mykdtree { std::vector<Eigen::Vector3d>& _points; node* node_root; public: mykdtree(std::vector<Eigen::Vector3d>& points) : _points(points) {} // ツリーの構築 void build() { // 最初のindexリストを作成 std::vector<size_t> indices(_points.size()); for (size_t i = 0; i < _points.size(); i++) { indices[i] = i; } // 再帰的にツリー作成 node_root = buildnode(_points, indices, 0/*X軸*/); } node* getRoot() { return node_root; }
// 半径内の頂点を探索する関数 // 再帰的に探索する関数 // nd: 現在のノード // query: 探索する点 // radius: 探索半径 // results: 探索結果のindexリスト void radiusSearch_core( node* nd, const Eigen::Vector3d& query, double radius, std::vector<size_t>& results) { if (nd == nullptr) { return; } const auto& nodePt = _points[nd->index]; // ノード と query の距離を計算 float dist = (query - nodePt).norm(); // 現在のノードがqueryを中心とした探索範囲内にあるか確認 if (dist <= radius) { results.push_back(nd->index);//現在のノードを結果に追加 } // サブツリーの探索条件 // 軸に沿った距離を計算 /////////////////////////////////////////////////////////// // // | // ------------●-------------○---------------- // query nodePt // | // // query.x - nodePt.x < 0 // queryが現在のノードの左側にある → nodePtの左側を探索 /////////////////////////////////////////////////////////// // // | // ------------○-------------●---------------- // nodePt query // | // // query.x - nodePt.x > 0 // queryが現在のノードの右側にある → nodePtの右側を探索 /////////////////////////////////////////////////////////// float axisDist = query[nd->axis] - nodePt[nd->axis]; // ケース1 queryがノードの左側にある if (axisDist < 0) { // 左側にあるので、左側を探索 radiusSearch_core(nd->left, query, radius, results); // ケース1(2)nodePtが、queryの半径以内にあるなら、右側にもあるかもしれない if (axisDist * axisDist <= radius * radius) { radiusSearch_core(nd->right, query, radius, results); } } else { // ケース2 queryがノードの右側にある // 右側にあるので、右側を探索 radiusSearch_core(nd->right, query, radius, results); // ケース2(2)nodePtが、queryの半径以内にあるなら、左側にもあるかもしれない if (axisDist * axisDist <= radius * radius) { radiusSearch_core(nd->left, query, radius, results); } } }
// 半径内の頂点を探索する関数 // query: 探索する点 // radius: 探索半径 // results: 探索結果のindexリスト void radiusSearch( const Eigen::Vector3d& query, double radius, std::vector<size_t>& results) { results.clear(); radiusSearch_core(node_root, query, radius, results); }
};
#include <iostream> //VTK_MODULE_INITに必要 #include <vtkAutoInit.h> #include <vtkSmartPointer.h> #include <vtkRenderer.h> #include <vtkRenderWindow.h> #include <vtkRenderWindowInteractor.h> //円筒とその表示に必要 #include <vtkPolyDataMapper.h> #pragma comment(lib,"opengl32.lib") #pragma comment(lib,"psapi.lib") #pragma comment(lib,"dbghelp.lib") #pragma comment(lib,"ws2_32.lib") #include <Eigen/Core> #include <array> #include <vtkActor.h> #include <vtkPoints.h> #include <vtkPolyData.h> #include <vtkUnsignedCharArray.h> #include <vtkPointData.h> #include <vtkVertexGlyphFilter.h> #include <vtkProperty.h> #include "mykdtree.hpp" //必須 VTK_MODULE_INIT(vtkRenderingOpenGL2); VTK_MODULE_INIT(vtkInteractionStyle); // VTK表示用 struct MyVtkCloud { std::vector<Eigen::Vector3d> points; std::array<unsigned char, 3> color; std::vector<std::array<unsigned char, 3> > color_array; vtkSmartPointer<vtkActor> actor; void makeActor() { // VTKのデータ構造に変換 vtkSmartPointer<vtkPoints> vtk_points = vtkSmartPointer<vtkPoints>::New(); vtkSmartPointer<vtkUnsignedCharArray> vtk_colors = vtkSmartPointer<vtkUnsignedCharArray>::New(); vtk_colors->SetNumberOfComponents(3); // RGB vtk_colors->SetName("Colors"); for (size_t i = 0; i < points.size(); ++i) { // 点を追加 vtk_points->InsertNextPoint(points[i].x(), points[i].y(), points[i].z()); if (color_array.size() == 0) { vtk_colors->InsertNextTypedTuple(color.data()); } else { vtk_colors->InsertNextTypedTuple(color_array[i].data()); } } //////////////////////////////////////////////////////////// vtkSmartPointer<vtkPolyData> polyData = vtkSmartPointer<vtkPolyData>::New(); polyData->SetPoints(vtk_points); polyData->GetPointData()->SetScalars(vtk_colors); //////////////////////////////////////////////////////////// vtkSmartPointer<vtkVertexGlyphFilter> vertexFilter = vtkSmartPointer<vtkVertexGlyphFilter>::New(); vertexFilter->SetInputData(polyData); vertexFilter->Update(); //////////////////////////////////////////////////////////// vtkSmartPointer<vtkPolyDataMapper> mapper = vtkSmartPointer<vtkPolyDataMapper>::New(); mapper->SetInputConnection(vertexFilter->GetOutputPort()); vtkSmartPointer<vtkActor> actor = vtkSmartPointer<vtkActor>::New(); actor->SetMapper(mapper); // 頂点サイズを指定 actor->GetProperty()->SetPointSize(5); // ここで頂点サイズを指定します this->actor = actor; } };
int main(int /*argc*/, char** /*argv*/) { MyVtkCloud cloud1; std::vector<std::array<unsigned char, 3> > color_array; srand((unsigned int)5); //// ランダムな点群を作成 for (size_t i = 0; i < 10000; ++i) { cloud1.points.push_back(Eigen::Vector3d::Random()); color_array.push_back({ 255,0,0 }); } cloud1.color = std::array<unsigned char, 3>{ 0, 255, 0 }; cloud1.makeActor(); mykdtree kdtree(cloud1.points); kdtree.build(); node* root = kdtree.getRoot(); std::vector<size_t> indices; kdtree.radiusSearch( Eigen::Vector3d(0.3,0.3,0.3), 0.3, indices); for(auto i:indices){ color_array[i] = { 0,255,0 }; } cloud1.color_array = color_array; cloud1.makeActor(); // 表示 std::vector<size_t> indices_look = indices; // 昇順ソート std::sort(indices_look.begin(), indices_look.end()); // コンソールに表示 for (auto i : indices_look) { std::cout << i << std::endl; } ////////////////////////////////////// auto renderer = vtkSmartPointer<vtkRenderer>::New(); renderer->AddActor(cloud1.actor); renderer->ResetCamera(); ////////////////////////////////////// auto interactor = vtkSmartPointer<vtkRenderWindowInteractor>::New(); ////////////////////////////////////// auto renderWindow = vtkSmartPointer<vtkRenderWindow>::New(); renderWindow->AddRenderer(renderer); renderWindow->SetInteractor(interactor); renderWindow->Render(); interactor->Start(); //イベントループへ入る return 0; }
partitionを使い、quick selectアルゴリズムを書く。quick selectでは「配列がソートされたときに、n番目に来る値」を取得できる。これを利用して中央値を取得できる。n番目の要素を特定するのに必要な処理が完了したらクイックソートを中断することで実現するので、データの順番が入れ替わるが、全部ソートするほどの時間はかからない。
C++でpartitionするプログラムを変更して[pivot未満][pivot][pivotより大きい]の領域に分割する
#include <iostream> #include <vector> #include <algorithm> #include "partition_3area.hpp"
void quick_select(std::vector<PairTypeFloat>& arr,size_t low,size_t high, size_t target_index) { size_t nlow = low; size_t nhigh = high; size_t tmp_pivot_index = low + (nhigh - nlow) / 2; //なにかしらpivotを決める // pivotでパーティション PartitionFuncFloat cond(arr, tmp_pivot_index); std::array<int64_t, 3> ret= partition_3area(arr, nlow, nhigh, cond); nlow = ret[0]; // 左側開始位置 size_t newpivot = ret[1]; // 処理後のpivotの位置 nhigh = ret[2]; // 右側終了位置 // [nlow~newpivot-1][newpivot][newpivot+1~nhigh] に分割される // このとき、 if (newpivot == target_index) { // [nlow~newpivot-1][target_index][newpivot+1~nhigh] であれば、 // target_indexの要素番号のデータはtarget_index番目の要素になっている return ; } else if (target_index < newpivot) { // [target_indexを含む][newpivot][newpivot+1~nhigh] であれば、 // 左側を処理 quick_select(arr, nlow, newpivot, target_index); } else { // [nlow~newpivot-1][newpivot][target_indexを含む] であれば、 // 右側を処理 quick_select(arr, newpivot+1, nhigh, target_index); } }
int main() { std::vector<PairTypeFloat> arr; arr.push_back(std::make_pair(3.7, "a")); arr.push_back(std::make_pair(2.3, "b")); arr.push_back(std::make_pair(7.5, "c")); arr.push_back(std::make_pair(1.2, "d")); arr.push_back(std::make_pair(6.8, "e")); arr.push_back(std::make_pair(5.6, "f")); arr.push_back(std::make_pair(2.4, "g")); arr.push_back(std::make_pair(0.9, "h")); arr.push_back(std::make_pair(6.1, "i")); arr.push_back(std::make_pair(5.8, "j"));// arr.push_back(std::make_pair(4.5, "k")); arr.push_back(std::make_pair(8.2, "l")); arr.push_back(std::make_pair(9.1, "m")); arr.push_back(std::make_pair(1.7, "n")); arr.push_back(std::make_pair(5.3, "o")); arr.push_back(std::make_pair(7.7, "p")); arr.push_back(std::make_pair(3.8, "q")); arr.push_back(std::make_pair(4.2, "r")); std::cout << "処理前の配列: " << std::endl; for (size_t i = 0; i < arr.size(); i++) { std::cout << "[" << i << "] " << arr[i].first << " " << arr[i].second << std::endl; } size_t target_index = arr.size() / 2; std::cout << target_index << "番目の要素: " << arr[target_index].first << " " << arr[target_index].second << std::endl; quick_select(arr, 0, arr.size(), target_index);// 中央値を求める std::cout << "--------------------------------------------" << std::endl; std::cout << "処理後の配列: " << std::endl; for (size_t i = 0; i < arr.size(); i++) { std::cout << "[" << i << "] " << arr[i].first << " " << arr[i].second << std::endl; } std::cout << target_index << "番目の要素: " << arr[target_index].first << " " << arr[target_index].second << std::endl; std::cout << "--------------------------------------------" << std::endl; std::cout << "ソート結果: " << std::endl; arr.clear(); arr.push_back(std::make_pair(3.7, "a")); arr.push_back(std::make_pair(2.3, "b")); arr.push_back(std::make_pair(7.5, "c")); arr.push_back(std::make_pair(1.2, "d")); arr.push_back(std::make_pair(6.8, "e")); arr.push_back(std::make_pair(5.6, "f")); arr.push_back(std::make_pair(2.4, "g")); arr.push_back(std::make_pair(0.9, "h")); arr.push_back(std::make_pair(6.1, "i")); arr.push_back(std::make_pair(5.8, "j")); // arr.push_back(std::make_pair(4.5, "k")); arr.push_back(std::make_pair(8.2, "l")); arr.push_back(std::make_pair(9.1, "m")); arr.push_back(std::make_pair(1.7, "n")); arr.push_back(std::make_pair(5.3, "o")); arr.push_back(std::make_pair(7.7, "p")); arr.push_back(std::make_pair(3.8, "q")); arr.push_back(std::make_pair(4.2, "r")); std::sort(arr.begin(), arr.end(), [](const PairTypeFloat& a, const PairTypeFloat& b) {return a.first < b.first; }); for (size_t i = 0; i < arr.size(); i++) { std::cout << "[" << i << "] " << arr[i].first << " " << arr[i].second << std::endl; } }
#include <iostream> #include <vector> #include <array>
// 配列の分割を行う関数 template<typename Container, class PCond> int64_t partition(Container& arr, int64_t low, int64_t high, PCond& pivot) { int64_t i = low; int64_t j = high-1; int64_t bound; while (true) { // 条件に合致するものを探す while ( (pivot.IsTrue(i) == true) ) { i++; if (i >= high) break; } // 条件に合致しないものを探す while ( (pivot.IsFalse(j) == true) ) { j--; if (j == -1) break; } // 左右の走査が交差した場合に分割終了 if (i >= j) { bound = i; break; } // 条件に合致しない要素とする要素を交換 pivot.Swap(i, j); // 走査を進める i++; j--; } return bound; }
enum BoolType { True = 1, False = 0 }; using PairTypeBool = std::pair<BoolType, std::string>; using PairTypeFloat = std::pair<float, std::string>;
// データを分割する時に必要な関数をまとめたもの struct PartitionFuncBool { std::vector<PairTypeBool>& _array; int64_t _pivot; float _value; PartitionFuncBool(std::vector<PairTypeBool>& arr,int64_t pivot_index):_array(arr) { _pivot = pivot_index; _value = _array[_pivot].first; } // 要素の交換関数 void Swap(size_t i, size_t j) { // pivotが移動したので最新のindexを保存 if(i==_pivot) _pivot = j; else if(j==_pivot) _pivot = i; std::swap(_array[i], _array[j]); } // 条件判定関数 bool IsTrue(size_t i) { return _array[i].first == _value; } bool IsFalse(size_t i) { return !IsTrue(i); } };
// データを分割する時に必要な関数をまとめたもの struct PartitionFuncFloat { std::vector<PairTypeFloat>& _array; int64_t _pivot; float _value; PartitionFuncFloat(std::vector<PairTypeFloat>& arr, int64_t pivot_index) :_array(arr) { _pivot = pivot_index; _value = _array[_pivot].first; } // 要素の交換関数 void Swap(size_t i, size_t j) { // pivotが移動したので最新のindexを保存 if (i == _pivot) _pivot = j; else if (j == _pivot) _pivot = i; std::swap(_array[i], _array[j]); } // 条件判定関数 bool IsTrue(size_t i) { return _array[i].first <= _value; } bool IsFalse(size_t i) { return !IsTrue(i); } };
template<typename Container, class PCond> std::array<int64_t, 3> partition_3area(Container& arr, int64_t low, int64_t high, PCond& cond) { int64_t bound_idx = partition< Container, PCond>(arr, low, high, cond); // cond._pivot ... 移動後のpivotのindex 前半にあるはず // bound_idx ... 条件に合致しない要素の最初のindex 後半にあるはず // pivotが前半の末尾以外にある場合は、前半末尾へ移動 if(cond._pivot != bound_idx-1) cond.Swap(cond._pivot, bound_idx-1); return { low, bound_idx-1,// 移動後のpivot high }; }
void testbool() { std::vector<PairTypeBool> arr; arr.push_back(std::make_pair(True, "a")); arr.push_back(std::make_pair(False, "b")); arr.push_back(std::make_pair(False, "c")); arr.push_back(std::make_pair(True, "d")); arr.push_back(std::make_pair(True, "e")); arr.push_back(std::make_pair(True, "f")); arr.push_back(std::make_pair(False, "g")); arr.push_back(std::make_pair(True, "h")); arr.push_back(std::make_pair(False, "i")); arr.push_back(std::make_pair(True, "j")); std::cout << "分割前の配列: " << std::endl; for (size_t i = 0; i < arr.size(); i++) { std::cout << "[" << i << "] " << arr[i].first << " " << arr[i].second << std::endl; } PartitionFuncBool cond(arr, arr.size() / 2); // ピボットの値を取得 // ピボット表示 std::cout << "ピボット: " << "[" << cond._pivot << "]" << arr[cond._pivot].first << " " << arr[cond._pivot].second << std::endl; auto ret = partition_3area(arr, (size_t)0, arr.size(), cond); std::cout << "分割後の配列: " << std::endl; for (size_t i = ret[0]; i < ret[1]; i++) { std::cout << "[" << i << "] " << arr[i].first << " " << arr[i].second << std::endl; } std::cout << std::endl; std::cout << "[" << ret[1] << "] " << arr[ret[1]].first << " " << arr[ret[1]].second << std::endl; std::cout << std::endl; for (size_t i = ret[1] + 1; i < ret[2]; i++) { std::cout << "[" << i << "] " << arr[i].first << " " << arr[i].second << std::endl; } }
void testfloat() { std::vector<PairTypeFloat> arr; arr.push_back(std::make_pair(3.7 , "a")); arr.push_back(std::make_pair(2.3 , "b")); arr.push_back(std::make_pair(7.5 , "c")); arr.push_back(std::make_pair(1.2 , "d")); arr.push_back(std::make_pair(6.8 , "e")); arr.push_back(std::make_pair(5.6 , "f")); arr.push_back(std::make_pair(2.4 , "g")); arr.push_back(std::make_pair(0.9 , "h")); arr.push_back(std::make_pair(6.1 , "i")); arr.push_back(std::make_pair(5.3 , "j")); arr.push_back(std::make_pair(4.5 , "k")); arr.push_back(std::make_pair(8.2 , "l")); arr.push_back(std::make_pair(9.1 , "m")); arr.push_back(std::make_pair(1.7 , "n")); arr.push_back(std::make_pair(5.3 , "o")); arr.push_back(std::make_pair(7.7 , "p")); arr.push_back(std::make_pair(3.8 , "q")); arr.push_back(std::make_pair(4.2 , "r")); std::cout << "分割前の配列: " << std::endl; for (size_t i = 0; i < arr.size(); i++) { std::cout << "[" << i << "] " << arr[i].first << " " << arr[i].second << std::endl; } PartitionFuncFloat cond(arr, arr.size()/2); // ピボットの値を取得 // ピボット表示 std::cout << "ピボット: " << "[" << cond._pivot << "]" << arr[cond._pivot].first << " " << arr[cond._pivot].second << std::endl; auto ret = partition_3area(arr, (size_t)0, arr.size(), cond); std::cout << "分割後の配列: " << std::endl; for (size_t i = ret[0]; i < ret[1]; i++) { std::cout << "[" << i << "] " << arr[i].first << " " << arr[i].second << std::endl; } std::cout << std::endl; std::cout << "[" << ret[1] << "] " << arr[ret[1]].first << " " << arr[ret[1]].second << std::endl; std::cout << std::endl; for (size_t i = ret[1] + 1; i < ret[2]; i++) { std::cout << "[" << i << "] " << arr[i].first << " " << arr[i].second << std::endl; } }
int main() { testfloat(); std::cout << "-----------------------------------\n"; testbool(); return 0; }
分割前の配列:
[0] 3.7 a
[1] 2.3 b
[2] 7.5 c
[3] 1.2 d
[4] 6.8 e
[5] 5.6 f
[6] 2.4 g
[7] 0.9 h
[8] 6.1 i
[9] 5.3 j
[10] 4.5 k
[11] 8.2 l
[12] 9.1 m
[13] 1.7 n
[14] 5.3 o
[15] 7.7 p
[16] 3.8 q
[17] 4.2 r
ピボット: [9]5.3 j
分割後の配列:
[0] 3.7 a
[1] 2.3 b
[2] 4.2 r
[3] 1.2 d
[4] 3.8 q
[5] 5.3 o
[6] 2.4 g
[7] 0.9 h
[8] 1.7 n
[9] 4.5 k
[10] 5.3 j
[11] 8.2 l
[12] 9.1 m
[13] 6.1 i
[14] 5.6 f
[15] 7.7 p
[16] 6.8 e
[17] 7.5 c
分割前の配列:
[0] 1 a
[1] 0 b
[2] 0 c
[3] 1 d
[4] 1 e
[5] 1 f
[6] 0 g
[7] 1 h
[8] 0 i
[9] 1 j
ピボット: [5]1 f
分割後の配列:
[0] 1 a
[1] 1 j
[2] 1 h
[3] 1 d
[4] 1 e
[5] 1 f
[6] 0 g
[7] 0 c
[8] 0 i
[9] 0 b