# stable_meta_learning

**Repository Path**: loxs/stable_meta_learning

## Basic Information

- **Project Name**: stable_meta_learning
- **Description**: Meta-learning using Stable-Baselines3
- **Primary Language**: Unknown
- **License**: GPL-2.0
- **Default Branch**: master
- **Homepage**: None
- **GVP Project**: No

## Statistics

- **Stars**: 0
- **Forks**: 0
- **Created**: 2023-10-31
- **Last Updated**: 2023-12-01

## Categories & Tags

**Categories**: Uncategorized
**Tags**: None

## README

# **Meta-learning using Stable-Baselines3**

## Install

```bash
git clone https://gitee.com/loxs/stable_meta_learning.git
pip install -e stable_meta_learning
```

## Usage

### Implemented environments

- panda_gym: https://github.com/qgallouedec/panda-gym

### Implemented methods

- PEARL: http://arxiv.org/abs/1903.08254
  - Reference implementation consulted for this repository: https://github.com/dongminlee94/meta-learning-for-everyone

```python
# Training script
import torch
from stable_baselines3.common.vec_env import DummyVecEnv

from stable_meta_learning.pearl import PEARL_SAC, MultiTaskHerReplayBuffer
from stable_meta_learning.envs import make_env

# Four training tasks that differ only in the object's lateral friction.
train_env_1 = make_env('PandaPush-v3', lateral_friction=0.5)
train_env_2 = make_env('PandaPush-v3', lateral_friction=1)
train_env_3 = make_env('PandaPush-v3', lateral_friction=1.5)
train_env_4 = make_env('PandaPush-v3', lateral_friction=2)
train_env = DummyVecEnv([train_env_1, train_env_2, train_env_3, train_env_4])

# PEARL + SAC model with a multi-task HER replay buffer.
model = PEARL_SAC(
    encoder_lr=1e-4,
    latent_dim=5,
    encoder_hidden_dim=300,
    kl_lambda=0.1,
    policy="MultiInputPolicy",
    env=train_env,
    batch_size=2048,
    gamma=0.95,
    learning_rate=1e-4,
    train_freq=64,
    gradient_steps=64,
    tau=0.05,
    replay_buffer_class=MultiTaskHerReplayBuffer,
    replay_buffer_kwargs=dict(
        n_sampled_goal=4,
        goal_selection_strategy="future",
    ),
    policy_kwargs=dict(
        net_arch=[512, 512, 512],
        n_critics=2,
    ),
    learning_starts=1000,
    verbose=1,
    device=torch.device('cuda:0'),
)

model.learn(total_timesteps=2_000_000, progress_bar=True)
model.save('checkpoints/PEARL_SAC')
train_env.close()
```
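Note that `DummyVecEnv([...])` expects a list of zero-argument callables that each build an environment, so `make_env` returns such a factory rather than an environment instance. The repository ships its own `make_env` in `stable_meta_learning.envs`; the snippet below is only a minimal sketch of that factory pattern, assuming panda-gym's gymnasium API and that `lateral_friction` is applied to the pushed object through panda-gym's PyBullet wrapper, which may differ from the actual implementation.

```python
# Minimal sketch of an environment factory -- NOT the repository's actual make_env.
# Assumes panda-gym 3.x (gymnasium API) and that the Push task registers the cube
# under the body name "object" in panda-gym's PyBullet wrapper.
import gymnasium as gym
import panda_gym  # noqa: F401  # registers PandaPush-v3 with gymnasium


def make_env(env_id: str, lateral_friction: float = 1.0):
    """Return a zero-argument factory, as DummyVecEnv([...]) expects."""

    def _init():
        env = gym.make(env_id)
        # Hypothetical friction override via panda-gym's PyBullet wrapper.
        env.unwrapped.sim.set_lateral_friction("object", -1, lateral_friction=lateral_friction)
        return env

    return _init
```

The test script below loads the trained checkpoint and adapts PEARL's latent task variable online in an environment whose friction was never seen during training.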
```python
# Test / adaptation script
from stable_baselines3.common.vec_env import DummyVecEnv
import numpy as np
import csv
import torch

from stable_meta_learning.pearl import PEARL_SAC
from stable_meta_learning.envs import make_env, get_state, states_to_result

# A single test task with an unseen, much higher lateral friction.
test_env_1 = make_env('PandaPush-v3', lateral_friction=10.0)
test_env = DummyVecEnv([test_env_1])

# Path to a checkpoint produced by the training script above.
model = PEARL_SAC.load(
    'meta-learning/stable_meta_learning/checkpoints/PEARL_SAC3.zip',
    env=test_env,
    device=torch.device('cuda:0'),
)

out_csv = 'train1234-test10.csv'

model.encoder.clear_z()  # start from the task prior; the posterior is adapted online
observations = test_env.reset()
actions = None
rewards = None
states = []
success_num = 0
pickandplace_num = 0
roll_num = 0
push_num = 0

for _ in range(100):
    _states = []
    while True:
        # test_env.render(mode='human')
        _states.append(get_state(test_env))
        if actions is not None:
            # Feed the latest transition to the context encoder and re-infer z.
            obs = []
            for key in observations:
                if key == 'task_z':
                    continue
                obs.append(observations[key])
            obs = np.concatenate(obs, axis=-1)
            model.update_context(obs, actions, rewards)
            model.encoder.infer_posterior(model.encoder.context)
        observations['task_z'] = model.encoder.task_z.cpu().detach().numpy()
        actions, _ = model.predict(
            observations,  # type: ignore[arg-type]
            deterministic=True,
        )
        observations, rewards, dones, infos = test_env.step(actions)
        if dones[0]:
            # print('->'.join(_states), end=' ')
            if infos[0]['is_success']:
                _states.append('success')
                print('success')
                success_num += 1
            else:
                _states.append('fail')
                print('fail')
            break
    states.append(_states)

# Write the per-episode state sequences and classify each episode's behaviour.
with open(out_csv, 'w', newline='', encoding='utf-8') as file:
    writer = csv.writer(file)
    for _states in states:
        writer.writerow(_states)
        result = states_to_result(_states)
        if result == 'push':
            push_num += 1
        elif result == 'roll':
            roll_num += 1
        else:
            pickandplace_num += 1

print('success_rate:', success_num / 100)
print('pickandplace_rate:', pickandplace_num / 100)
print('roll_rate:', roll_num / 100)
print('push_rate:', push_num / 100)

test_env.close()
```
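At test time the policy adapts entirely online: after `model.encoder.clear_z()` it starts from the task prior, and every transition is pushed into the context via `model.update_context` before `infer_posterior` refreshes `task_z`. To probe several unseen friction values rather than a single one, the same loop can be wrapped in a small sweep. The sketch below reuses only the interfaces shown above; the helper `run_episode`, the checkpoint path, and the friction values are illustrative, not part of the package.

```python
# Sketch of a friction sweep reusing the adaptation loop above.
# run_episode and the friction values are illustrative; only make_env, PEARL_SAC
# and the encoder calls shown in the test script are assumed to exist.
import numpy as np
import torch
from stable_baselines3.common.vec_env import DummyVecEnv

from stable_meta_learning.pearl import PEARL_SAC
from stable_meta_learning.envs import make_env


def run_episode(model, env):
    """Roll out one episode, adapting the task posterior at every step."""
    observations = env.reset()
    actions, rewards = None, None
    while True:
        if actions is not None:
            # Feed the latest transition to the context encoder and re-infer z.
            obs = np.concatenate(
                [observations[key] for key in observations if key != 'task_z'], axis=-1
            )
            model.update_context(obs, actions, rewards)
            model.encoder.infer_posterior(model.encoder.context)
        observations['task_z'] = model.encoder.task_z.cpu().detach().numpy()
        actions, _ = model.predict(observations, deterministic=True)
        observations, rewards, dones, infos = env.step(actions)
        if dones[0]:
            return bool(infos[0]['is_success'])


for friction in (0.25, 5.0, 10.0):
    env = DummyVecEnv([make_env('PandaPush-v3', lateral_friction=friction)])
    model = PEARL_SAC.load('checkpoints/PEARL_SAC.zip', env=env, device=torch.device('cuda:0'))
    model.encoder.clear_z()  # start each task from the prior; context accumulates across episodes
    successes = sum(run_episode(model, env) for _ in range(20))
    print(f'lateral_friction={friction}: success_rate={successes / 20:.2f}')
    env.close()
```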