PyTorch
PyTorchは、特に研究者の間で、Pythonにおけるディープラーニングのための最も人気のあるフレームワークの一つです。W&Bは、PyTorchに対して、勾配のログ取得からCPUとGPUのコードプロファイリングまで、最高レベルのサポートを提供しています。
私たちのインテグレーションをColabノートブックでお試しいただくか(ビデオ解説付き)、スクリプトが含まれているexampleリポジトリをご覧ください。これには、Fashion MNISTにHyperbandを使用したハイパーパラメータ最適化に関するものが含まれています。さらに、それが生成するW&B ダッシュボードもご覧いただけます。
wandb.watch
を使った勾配の記録
自動的に勾配をログに記録するには、wandb.watch
を呼び出して、PyTorchモデルを渡します。
import wandb
wandb.init(config=args)
model = ... # モデルの設定
# Magic
wandb.watch(model, log_freq=100)
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
wandb.log({"loss": loss})
同じスクリプト内で複数のモデルをトラッキングする必要がある場合は、それぞれのモデルでwandb.watch
を呼び出すことができます。この関数のリファレンスドキュメントはこちらです。
勾配、メトリクス、およびグラフは、forwardパスとbackwardパスの後にwandb.log
が呼び出されるまでログに記録されません。
画像やメディアのログ記録
PyTorchのTensors
に画像データを渡すと、wandb.Image
とtorchvision
のユーティリティが自動的に画像に変換します。
images_t = ... # PyTorch Tensorsとして画像を生成または読み込み
wandb.log({"examples": [wandb.Image(im) for im in images_t]})
PyTorchや他のフレームワークでW&Bにリッチメディアをログに記録する方法については、media logging guideを参照してください。
また、メディアに情報を添えて記録したい場合(モデルの予測や派生したメトリクスなど)、wandb.Table
を使用してください。
my_table = wandb.Table()
my_table.add_column("image", images_t)
my_table.add_column("label", labels)
my_table.add_column("class_prediction", predictions_t)
# W&BにTableをログする
wandb.log({"mnist_predictions": my_table})
上のコードでこのようなテーブルが生成されます。このモデルは良さそうです!
データセットとモデルのログや可視化について詳しくは、W&B Tables のガイドを参照してください。
PyTorch コードのプロファイリング
W&Bは PyTorch Kinetoの Tensorboardプラグインと直接統合して、PyTorch コードのプロファイリングツール、CPU と GPU 通信の詳細の検査、ボトルネックの特定や最適化のためのツールを提供しています。
profile_dir = "path/to/run/tbprofile/"
profiler = torch.profiler.profile(
schedule=schedule, # スケジューリングの詳細についてはプロファイラのドキュメントを参照してください
on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_dir),
with_stack=True,
)
with profiler:
... # ここでプロファイリングしたいコードを実行
# 詳細な使用方法については、プロファイラのドキュメントを参照してください
# wandbアーティファクトを作成
profile_art = wandb.Artifact("trace", type="profile")
# pt.trace.jsonファイルをアーティファクトに追加
profile_art.add_file(glob.glob(profile_dir + ".pt.trace.json"))
# アーティファクトをログに記録
profile_art.save()
このColabで実例コードを確認して実行してください。
:::注意
インタラクティブなトレースビューアツールは、Chrome Trace Viewerをベースにしており、Chromeブラウザで最適に動作します。
:::