ABEJA Tech Blog

中の人の興味のある情報を発信していきます

ロボティクスモデルの精度向上の挑戦 〜前処理モデル追加編(深度推定による精度向上)〜

こちらは「ロボティクスモデルの精度向上の挑戦」の後編記事になります。

前編は以下の記事をご参照ください。

tech-blog.abeja.asia

前編では、データオーギュメントの改善によりモデルの汎化性能を向上させることを実現しましたが、今回の記事では前処理モデルを追加することにより、モデルの Max 品質を向上させることを目指します。

前処理モデル追加によるモデル改善

今回の方法では、コアのAIモデル(今回の場合はロボティクスモデル)とは別の前処理用のAIモデルを追加する方法でのモデル改善を行ってみます

この際のポイントは、正解ではないけど正解のヒントとなる情報を前処理モデルで作成して、コアのAIモデルに追加入力するという点になります。

深度マップ用の前処理モデル追加

今回のロボティクスのタスク(オブジェクトを掴んで別のオブジェクトにはめ込むタスク)では、以下のようなカメラ画像の深度マップの画像も入力すれば、モデルが各オブジェクトの深さ情報も理解しやすくなって、タスクの成功率があがるのではないかと考えられます。

推論パイプライン全体を図示すると以下のようになります。

論文「Learning Fine-Grained Bimanual Manipulation with Low-Cost Hardware」より引用し改変

さてロボティクスモデルに深度マップ画像を入力&解釈できるようにするには、深度マップ画像が付与された学習用データセットが必要になりますが、そんな都合の良いデータセットは公開されていないので、lerobot/aloha_sim_insertion_human_image データセットをベースに以下の手順で自作します。

  1. lerobot/aloha_sim_insertion_human_image データセットをローカルにダウンロード
  2. 各エポックの各時間ステップのデータ抽出
  3. 上記データセットのカメラ画像から深度マップを生成

    今回の例では、この際の深度マップ画像を生成する前処理モデルとして、以下の Depth-Anything-V2 を使用します。 github.com

  4. その他は同じ内容で LeRobot データセット形式にして保存する

参考までに、今回の学習用データセット作成のコード例も乗せておきます

  • コード例

      import argparse
      import os
      import shutil
      from pathlib import Path
      from typing import Any
    
      import cv2
      import lerobot
      import matplotlib.pyplot as plt
      import numpy as np
      import torch
      import torch.nn.functional as F
      from lerobot.common.datasets.lerobot_dataset import (
          LeRobotDataset,
          LeRobotDatasetMetadata
      )
      from matplotlib import cm
      from tqdm import tqdm
    
      # NOTE: git clone https://github.com/DepthAnything/Depth-Anything-V2 and fix path your appropriate directory.
      depth_anything_path = os.path.join(os.path.dirname(__file__), "..", "Depth-Anything-V2")
      import sys
    
      sys.path.append(depth_anything_path)
      from depth_anything_v2.dpt import DepthAnythingV2
    
      def create_dataset(
          model: Any,
          original_dataset_id: str,
          custom_dataset_id: str,
          fps: int = 50,
          robot_type: str = "aloha",
          push_to_hub: bool = False,
      ):
          # get metadata from original dataset
          original_metadata = LeRobotDatasetMetadata(original_dataset_id)
          features = original_metadata.features.copy()
          print("original features:", features)
    
          # add new feature
          features["observation.depths.top"] = {
              "dtype": "image",
              "shape": features["observation.images.top"]["shape"],
              "names": features["observation.images.top"]["names"],
          }
          print("new features:", features)
    
          # create new dataset
          custom_dataset = LeRobotDataset.create(
              repo_id=custom_dataset_id,
              fps=fps,
              robot_type=robot_type,
              features=features,
              use_videos=True,
              image_writer_processes=4,
              image_writer_threads=8,
          )
    
          # load original dataset
          original_dataset = LeRobotDataset(original_dataset_id)
    
          # copy all data
          num_episodes = len(original_dataset.episode_data_index["from"])
    
          for episode_idx in tqdm(range(num_episodes), desc="Creating dataset with episodes"):
              global_frame_start = original_dataset.episode_data_index["from"][
                  episode_idx
              ].item()
              global_frame_end = original_dataset.episode_data_index["to"][episode_idx].item()
              for frame_idx in tqdm(
                  range(global_frame_start, global_frame_end),
                  desc="Creating dataset with frames",
              ):
                  frame = original_dataset[frame_idx]
                  new_frame = {}
    
                  # copy original frame data
                  for k, v in frame.items():
                      if k in [
                          "task_index",
                          "episode_index",
                          "index",
                          "frame_index",
                          "timestamp",
                      ]:
                          continue
                      if (
                          k == "observation.images.top"
                          and isinstance(v, torch.Tensor)
                          and v.shape[0] == 3
                      ):
                          v = v.permute(1, 2, 0)
                      if k in ["next.done"]:
                          v = v.unsqueeze(0)
                      new_frame[k] = v.clone() if hasattr(v, "clone") else v
    
                  img = frame["observation.images.top"]
                  if isinstance(img, torch.Tensor):
                      img = img.permute(1, 2, 0).cpu().numpy()
                  if img.dtype != np.uint8:
                      img = (img * 255).astype(np.uint8)
    
                  # add depth image as new feature data
                  depth_map = model.infer_image(img)
                  depth_map = (depth_map - depth_map.min()) / (
                      depth_map.max() - depth_map.min()
                  )
                  depth_map = np.stack([depth_map, depth_map, depth_map], axis=-1)
                  new_frame["observation.depths.top"] = depth_map
    
                  # add 1 step data (frame)
                  custom_dataset.add_frame(new_frame)
    
              # save episode
              custom_dataset.save_episode()
    
          # push to huggingface
          if push_to_hub:
              custom_dataset.push_to_hub()
    
      if __name__ == "__main__":
          parser = argparse.ArgumentParser()
          parser.add_argument(
              "--original_dataset_id",
              type=str,
              default="lerobot/aloha_sim_insertion_human_image",
          )
          parser.add_argument(
              "--custom_dataset_id",
              type=str,
              default="Yagami360/aloha_sim_insertion_human_with_depth_images_20250619",
          )
          parser.add_argument("--fps", type=int, default=50)
          parser.add_argument("--robot_type", type=str, default="aloha")
          parser.add_argument("--push_to_hub", action="store_true", default=False)
          parser.add_argument(
              "--model_checkpoint_path",
              type=str,
              default="../checkpoints/depth_anything_v2/depth_anything_v2_vitb.pth",
          )
          parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
          args = parser.parse_args()
    
          model_configs = {
              "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
              "vitb": {
                  "encoder": "vitb",
                  "features": 128,
                  "out_channels": [96, 192, 384, 768],
              },
              "vitl": {
                  "encoder": "vitl",
                  "features": 256,
                  "out_channels": [256, 512, 1024, 1024],
              },
              "vitg": {
                  "encoder": "vitg",
                  "features": 384,
                  "out_channels": [1536, 1536, 1536, 1536],
              },
          }
          if args.model_checkpoint_path.endswith("depth_anything_v2_vits.pth"):
              encoder = "vits"
          elif args.model_checkpoint_path.endswith("depth_anything_v2_vitb.pth"):
              encoder = "vitb"
          elif args.model_checkpoint_path.endswith("depth_anything_v2_vitl.pth"):
              encoder = "vitl"
          else:
              raise ValueError(f"Invalid model checkpoint path: {args.model_checkpoint_path}")
    
          model = DepthAnythingV2(**model_configs[encoder])
          model.load_state_dict(torch.load(args.model_checkpoint_path, map_location="cpu"))
          model = model.to(args.device).eval()
    
          custom_dataset = create_dataset(
              model=model,
              original_dataset_id=args.original_dataset_id,
              custom_dataset_id=args.custom_dataset_id,
              fps=args.fps,
              robot_type=args.robot_type,
              push_to_hub=args.push_to_hub,
          )
    

また上記コードで自作したデータセット自体も以下の場所にアップロードしておきました huggingface.co

さて次に、このように作成したデータセットで ACT モデルを学習してます。

通常モデルに別のデータを追加入力する際には、モデル定義の class に入力層を追加したり入力層の次元を増やしたりする必要がありますが、幸い Lerobot の学習コードは、Lerobot 形式のデータセット属性に沿って汎用性に学習できる作りになっているので、モデル内部の実装を変更することなく深度マップ画像を入力できました。

一応 ACT モデル定義の forward メソッド内にデバッグコードを追加して深度マップ画像を concat で結合してモデルに入力されていることも確認済みです。

なお LeRobot の ACT の画像エンコーダーでは、デフォルトで ImageNet で事前学習した Encoder を使用しており、深度マップ画像(グレースケール)の特徴抽出に最適化されていないというのはありますが、Encoder を含めて後段のネットワークとともに学習されるので一定程度深度情報を学習&解釈できるようになると思います。モデルがより強く深度情報を学習&解釈できるようにするためには、別途深度マップ画像専用の Encoder を入力層に追加したり、モデルが学習&解釈できているかを特徴マップのヒートマップ等で可視化したりしながら確認すると良いかもしれません。

学習完了後は、シミュレーター環境上でこのモデルを動かすのですが、今回は深度マップ画像を追加したので、シミュレーターから得られるカメラ画像から前処理モデル(Depth-Anything-V2)で深度マップ画像を生成したうえで動かす必要があります

以下、結果です

モデル 学習ステップ数 成功率(at 50 episode)
改善前モデル(ACT without depth map) 20K 18.00%
改善後モデル(ACT with depth map) 20K 24.00%

予想してたよりは成功率が向上しなかったものの、6%程度成功率が向上しました!

(今回はやってないですが、深度マップ画像専用の Encoder を別途追加するなどすれば更に精度が向上するかもです)

但し、前処理や後処理モデルを追加することでのデメリットとしては、前処理モデルや後処理モデルの推論が追加で必要になるので、推論パイプライン全体での推論時間が伸びる点や各前処理モデルの品質が悪いとパイプライン全体での品質が悪化する点などがあります。

特にロボティクス制御ではリアルタイム性が要求されるケースが多いので、推論パイプライン全体の処理時間は重要になると思います。

今回の記事では、時間の都合上深度マップ画像を追加入力するパターンだけのモデル改善を試しましたが、他にも例えば、セグメンテーションマスク画像やロボットの関節点画像などのヒント情報を前処理モデルで生成して追加入力することでもモデル改善ができると思っています。 また複数の前処理モデルを利用する場合は、それぞれのモデルを並列処理することで推論時間を短縮させるなどの方法もあります。

最後に、今回の検証に使用したシミュレーター環境上での推論コード例も乗せておきます

  • コード例

      import argparse
      import importlib.util
      import os
    
      import cv2
      import gym_aloha  # noqa: F401
      import gymnasium as gym
      import imageio
      import lerobot
      import numpy as np
      import torch
      from lerobot.common.policies.act.modeling_act import ACTConfig, ACTPolicy
      from lerobot.common.policies.pi0.modeling_pi0 import PI0Policy
    
      try:
          # NOTE: git clone https://github.com/DepthAnything/Depth-Anything-V2 and fix path your appropriate directory.
          depth_anything_path = os.path.join(
              os.path.dirname(__file__), "..", "Depth-Anything-V2"
          )
          import sys
    
          sys.path.append(depth_anything_path)
          from depth_anything_v2.dpt import DepthAnythingV2
      except:
          print("If you want to use depth map, please install Depth-Anything-V2.")
    
      def add_occlusion(
          image, start_x, start_y, occlusion_height, occlusion_width, alpha=1.0
      ):
          overlay = image.copy()
          cv2.rectangle(
              overlay,
              (start_x, start_y),
              (start_x + occlusion_width, start_y + occlusion_height),
              (0, 0, 0),
              -1,
          )
          return cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0)
    
      if __name__ == "__main__":
          parser = argparse.ArgumentParser()
          parser.add_argument("--model_type", type=str, default="act", choices=["act", "pi0"])
          parser.add_argument("--output_dir", type=str, default="outputs/eval/act_aloha")
          parser.add_argument(
              "--load_checkpoint_dir",
              type=str,
              default="../checkpoints/act-aloha-wo-dataaug-20250614/checkpoints/020000/pretrained_model",
          )
          parser.add_argument("--num_episodes", type=int, default=50)
          parser.add_argument("--max_episode_steps", type=int, default=500)
          parser.add_argument("--fix_seed", action="store_true")
          parser.add_argument("--seed", type=int, default=8)
          parser.add_argument("--gpu_id", type=int, default=0)
          parser.add_argument("--normalize_img", action="store_true", default=True)
          parser.add_argument("--depth_model_checkpoint_path", type=str, default=None)
          parser.add_argument("--occlusion", action="store_true")
          parser.add_argument("--occlusion_shuffle", action="store_true")
          parser.add_argument("--occlusion_x", type=int, default=250)
          parser.add_argument("--occlusion_y", type=int, default=230)
          parser.add_argument("--occlusion_w", type=int, default=75)
          parser.add_argument("--occlusion_h", type=int, default=75)
          parser.add_argument("--occlusion_alpha", type=float, default=1.0)
          parser.add_argument("--blur", action="store_true")
          parser.add_argument("--blur_kernel_size", type=int, default=15)
          args = parser.parse_args()
          for arg in vars(args):
              print(f"{arg}: {getattr(args, arg)}")
    
          os.makedirs(args.output_dir, exist_ok=True)
    
          if args.gpu_id < 0:
              device = "cpu"
          else:
              device = "cuda"
              os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)
    
          if args.fix_seed:
              np.random.seed(args.seed)
              torch.manual_seed(args.seed)
    
          # Define simulation environment with AlohaInsertion-v0
          os.environ["MUJOCO_GL"] = "egl"
    
          env = gym.make(
              "gym_aloha/AlohaInsertion-v0",
              obs_type="pixels_agent_pos",
              max_episode_steps=args.max_episode_steps,
          )
          print("env.observation_space:", env.observation_space)
          print("env.action_space:", env.action_space)
    
          # Load model (policy)
          if args.model_type == "act":
              policy = ACTPolicy.from_pretrained(
                  args.load_checkpoint_dir,
                  strict=False,
              )
          elif args.model_type == "pi0":
              policy = PI0Policy.from_pretrained(
                  args.load_checkpoint_dir,
                  strict=False,
              )
    
          policy.reset()
          print("Policy config:", vars(policy.config))
          print("policy.config.input_features:", policy.config.input_features)
          print("policy.config.output_features:", policy.config.output_features)
    
          # Load depth map preprocessing model
          if (
              args.depth_model_checkpoint_path is not None
              and args.depth_model_checkpoint_path != ""
          ):
              model_configs = {
                  "vits": {
                      "encoder": "vits",
                      "features": 64,
                      "out_channels": [48, 96, 192, 384],
                  },
                  "vitb": {
                      "encoder": "vitb",
                      "features": 128,
                      "out_channels": [96, 192, 384, 768],
                  },
                  "vitl": {
                      "encoder": "vitl",
                      "features": 256,
                      "out_channels": [256, 512, 1024, 1024],
                  },
                  "vitg": {
                      "encoder": "vitg",
                      "features": 384,
                      "out_channels": [1536, 1536, 1536, 1536],
                  },
              }
              encoder = "vitb"
    
              depth_model = DepthAnythingV2(**model_configs[encoder])
              depth_model.load_state_dict(
                  torch.load(args.depth_model_checkpoint_path, map_location="cpu")
              )
              depth_model = depth_model.to(device).eval()
    
          # -----------------------------------------------
          # Infer policy with simulation environment
          # -----------------------------------------------
          num_success = 0
          num_failure = 0
          occlusion_x = args.occlusion_x
          occlusion_y = args.occlusion_y
          occlusion_h = args.occlusion_h
          occlusion_w = args.occlusion_w
          occlusion_alpha = args.occlusion_alpha
    
          for episode in range(args.num_episodes):
              policy.reset()
              if args.fix_seed:
                  observation_np, info = env.reset(seed=args.seed)
              else:
                  observation_np, info = env.reset()
    
              rewards = []
              frames = []
              frames_depth = []
              step = 0
              done = False
    
              # Render initial frame
              frame = env.render()
              if args.occlusion:
                  frame = add_occlusion(
                      frame,
                      occlusion_x,
                      occlusion_y,
                      occlusion_h,
                      occlusion_w,
                      occlusion_alpha,
                  )
              if args.blur:
                  frame = cv2.GaussianBlur(
                      frame, (args.blur_kernel_size, args.blur_kernel_size), 0
                  )
    
              frames.append(frame)
    
              while not done:
                  # aloha environment has x-y position of the agent as the observation
                  state = torch.from_numpy(observation_np["agent_pos"]).to(device)
                  state = state.to(torch.float32)
                  state = state.unsqueeze(0)
    
                  # aloha environment has RGB image of the environment as the observation
                  image = torch.from_numpy(observation_np["pixels"]["top"]).to(device)
                  if args.normalize_img:
                      image = image.to(torch.float32) / 255
                  image = image.permute(2, 0, 1)
                  image = image.unsqueeze(0)
    
                  image_np = image.squeeze(0).permute(1, 2, 0).cpu().numpy()
                  if args.normalize_img:
                      image_np = (image_np * 255).astype(np.uint8)
    
                  # add some occlusion mask to the env image
                  if args.occlusion:
                      image_np = add_occlusion(
                          image_np,
                          args.occlusion_x,
                          args.occlusion_y,
                          args.occlusion_h,
                          args.occlusion_w,
                          args.occlusion_alpha,
                      )
                  if args.blur:
                      image_np = cv2.GaussianBlur(
                          image_np, (args.blur_kernel_size, args.blur_kernel_size), 0
                      )
    
                  image = torch.from_numpy(image_np).to(device)
                  if args.normalize_img:
                      image = image.to(torch.float32) / 255
                  image = image.permute(2, 0, 1)
                  image = image.unsqueeze(0)
                  # cv2.imwrite(f"{args.output_dir}/env_image.png", cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR))
    
                  # aloha insert task expects the following observation format
                  observation = {
                      # agent's x-y position
                      "observation.state": state,
                      # environment's RGB image
                      "observation.images.top": image,
                      # agent's control instruction text
                      "task": ["Insert the peg into the socket"],
                  }
    
                  # add depth map to the observation
                  if (
                      args.depth_model_checkpoint_path is not None
                      and args.depth_model_checkpoint_path != ""
                  ):
                      depth_map = depth_model.infer_image(image_np)
                      depth_map_vis = ((depth_map - depth_map.min()) / (depth_map.max() - depth_map.min()) * 255).astype(np.uint8)
                      depth_map_vis = np.stack([depth_map_vis, depth_map_vis, depth_map_vis], axis=-1)
                      frames_depth.append(depth_map_vis)
                      # cv2.imwrite(f"{args.output_dir}/env_depth_map.png", cv2.cvtColor(depth_map_vis, cv2.COLOR_RGB2BGR))
    
                      if args.normalize_img:
                          depth_map = (depth_map - depth_map.min()) / (depth_map.max() - depth_map.min())
    
                      depth_map = np.stack([depth_map, depth_map, depth_map], axis=0)
                      depth_map = torch.from_numpy(depth_map).unsqueeze(0).to(device)
                      observation["observation.depths.top"] = depth_map
    
                  if episode == 0 and step == 0:
                      for key in observation:
                          if isinstance(observation[key], torch.Tensor) or isinstance(
                              observation[key], np.ndarray
                          ):
                              print(
                                  f"[observation.{key}] shape={observation[key].shape}, min={observation[key].min()}, max={observation[key].max()}, dtype={observation[key].dtype}"
                              )
    
                  # infer the next action based on the policy
                  with torch.inference_mode():
                      action = policy.select_action(observation)
                      if episode == 0 and step == 0:
                          print(
                              f"[action] shape={action.shape}, min={action.min()}, max={action.max()}, dtype={action.dtype}"
                          )
    
                  # step through the simulation environment and receive a new observation
                  action_np = action.squeeze(0).to("cpu").numpy()
                  if episode == 0 and step == 0:
                      print(
                          f"[action_np] shape={action_np.shape}, min={action_np.min()}, max={action_np.max()}, dtype={action_np.dtype}"
                      )
    
                  observation_np, reward, terminated, truncated, info = env.step(action_np)
                  print(f"{step=} {reward=} {terminated=}")
    
                  # render the environment
                  frame = env.render()
                  if args.occlusion:
                      frame = add_occlusion(
                          frame,
                          occlusion_x,
                          occlusion_y,
                          occlusion_h,
                          occlusion_w,
                          occlusion_alpha,
                      )
                  if args.blur:
                      frame = cv2.GaussianBlur(
                          frame, (args.blur_kernel_size, args.blur_kernel_size), 0
                      )
                  # cv2.imwrite(f"{args.output_dir}/env_frame.png", cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
                  frames.append(frame)
    
                  # keep track of all the rewards
                  rewards.append(reward)
    
                  # finish inference when the success state is reached (i.e. terminated is True),
                  # or the maximum number of iterations is reached (i.e. truncated is True)
                  done = terminated | truncated | done
                  step += 1
    
              if terminated:
                  print("Success!")
                  num_success += 1
              else:
                  print("Failure!")
                  num_failure += 1
    
              # save the simulation frames as a video
              if terminated:
                  video_name = f"eval_frames_ep{episode}_ok.mp4"
              else:
                  video_name = f"eval_frames_ep{episode}_ng.mp4"
              video_path = os.path.join(args.output_dir, video_name)
              imageio.mimsave(str(video_path), np.stack(frames), fps=env.metadata["render_fps"])
              print(f"Video of the evaluation is available in '{video_path}'.")
    
              if (
                  args.depth_model_checkpoint_path is not None
                  and args.depth_model_checkpoint_path != ""
              ):
                  if terminated:
                      video_name = f"eval_frames_depth_ep{episode}_ok.mp4"
                  else:
                      video_name = f"eval_frames_depth_ep{episode}_ng.mp4"
                  video_path = os.path.join(args.output_dir, video_name)
                  imageio.mimsave(str(video_path), np.stack(frames_depth), fps=env.metadata["render_fps"])
    
              # shuffle occlusion position and size for next episode
              if args.occlusion_shuffle:
                  x_range = np.arange(-100, 101, 10)
                  y_range = np.arange(-100, 101, 10)
                  h_range = np.arange(-10, 51, 10)
                  w_range = np.arange(-10, 51, 10)
    
                  occlusion_x = args.occlusion_x + np.random.choice(x_range)
                  occlusion_y = args.occlusion_y + np.random.choice(y_range)
                  occlusion_h = args.occlusion_h + np.random.choice(h_range)
                  occlusion_w = args.occlusion_w + np.random.choice(w_range)
    
          print(f"Success rate: {num_success / args.num_episodes * 100:.2f}%")
          print(f"Failure rate: {num_failure / args.num_episodes * 100:.2f}%")
    
          # write evaluation results to a file
          with open(os.path.join(args.output_dir, "eval_results.txt"), "w") as f:
              for arg in vars(args):
                  f.write(f"{arg}: {getattr(args, arg)}\n")
              f.write(f"Success rate: {num_success / args.num_episodes * 100:.2f}%\n")
              f.write(f"Failure rate: {num_failure / args.num_episodes * 100:.2f}%\n")
    

まとめ

今回の記事ではロボティクスモデルの改善を初めてやってみましたが、ロボティクス領域という自分に取っては未知な領域においても AI モデルの基本的なモデル改善ノウハウがそのまま有効であることがわかったのは良かったです。

やはり AI の民主化時代にあっても機械学習モデルの基本を抑えておくのは大事なのだと改めて思いました。

We Are Hiring!

株式会社ABEJAでは共に働く仲間を募集しています!

ロボティクスやLLMに興味ある方々!機械学習プロダクトに関わるフロントエンド開発やバックエンド開発に興味ある方々! こちらの採用ページから是非ご応募くださいませ!

careers.abejainc.com