mxnet系列教程之1-第一個例子

第一個例子當然是mnist的例子

假設已經成功安裝了mxnet

例子的代碼如下:

cd mxnet/example/image-classification
python train_mnist.py
這樣就會運行下去

train_mnist.py的代碼為

"""
Train mnist, see more explanation at http://mxnet.io/tutorials/python/mnist.html
"""
import os
import argparse
import logging
logging.basicConfig(level=logging.DEBUG)
from common import find_mxnet, fit
from common.util import download_file
import mxnet as mx
import numpy as np
import gzip
import struct


def read_data(label, image):
    """Download one MNIST archive pair and decode it into numpy arrays.

    Parameters
    ----------
    label : str
        Filename of the gzipped label archive on the MNIST site.
    image : str
        Filename of the gzipped image archive.

    Returns
    -------
    (labels, images) : tuple of np.ndarray
        labels has shape (N,) int8; images has shape (N, rows, cols) uint8.
    """
    base_url = 'http://yann.lecun.com/exdb/mnist/'
    with gzip.open(download_file(base_url + label, os.path.join('data', label))) as flbl:
        magic, num = struct.unpack(">II", flbl.read(8))
        # np.frombuffer replaces the deprecated np.fromstring
        label = np.frombuffer(flbl.read(), dtype=np.int8)
    with gzip.open(download_file(base_url + image, os.path.join('data', image)), 'rb') as fimg:
        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        image = np.frombuffer(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
    return (label, image)


def to4d(img):
    """Reshape (N, 28, 28) uint8 images to NCHW float32 scaled into [0, 1]."""
    return img.reshape(img.shape[0], 1, 28, 28).astype(np.float32) / 255


def get_mnist_iter(args, kv):
    """Create (train, val) NDArrayIter data iterators over MNIST.

    Only args.batch_size is read here; kv is accepted for interface
    compatibility with fit.fit's data_loader contract.
    """
    (train_lbl, train_img) = read_data(
        'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz')
    (val_lbl, val_img) = read_data(
        't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz')
    train = mx.io.NDArrayIter(
        to4d(train_img), train_lbl, args.batch_size, shuffle=True)
    val = mx.io.NDArrayIter(to4d(val_img), val_lbl, args.batch_size)
    return (train, val)


if __name__ == '__main__':
    # parse args
    parser = argparse.ArgumentParser(
        description="train mnist",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--num-classes', type=int, default=10,
                        help='the number of classes')
    parser.add_argument('--num-examples', type=int, default=60000,
                        help='the number of training examples')
    fit.add_fit_args(parser)
    parser.set_defaults(
        # network
        network        = 'mlp',
        # train
        gpus           = '0,1',
        batch_size     = 64,
        disp_batches   = 100,
        num_epochs     = 20,
        lr             = .05,
        lr_step_epochs = '10',
        # BUG FIX: was misspelled 'model_predix', which left args.model_prefix
        # unset so _save_model never produced a checkpoint callback
        model_prefix   = './my')
    args = parser.parse_args()

    # load network: each symbol module under symbols/ exposes get_symbol()
    from importlib import import_module
    net = import_module('symbols.' + args.network)
    sym = net.get_symbol(**vars(args))

    # train
    fit.fit(args, sym, get_mnist_iter)

net寫在了symbol文件夾中,相當于caffe的prototxt文件

model_prefix(原代碼中誤拼為model_predix,應更正,否則不會保存模型)相當于caffe的保存模型前綴

打印的信息為:Saved checkpoint to "./my-0001.params"

"""
a simple multilayer perceptron
"""
import mxnet as mx


def get_symbol(num_classes=10, **kwargs):
    """Build a 3-layer MLP symbol: 128 -> 64 -> num_classes, softmax output.

    Extra keyword arguments are accepted (and ignored) so the full parsed
    argument namespace can be splatted in by the caller.
    """
    net = mx.symbol.Variable('data')
    net = mx.sym.Flatten(data=net)
    net = mx.symbol.FullyConnected(data=net, name='fc1', num_hidden=128)
    net = mx.symbol.Activation(data=net, name='relu1', act_type="relu")
    net = mx.symbol.FullyConnected(data=net, name='fc2', num_hidden=64)
    net = mx.symbol.Activation(data=net, name='relu2', act_type="relu")
    net = mx.symbol.FullyConnected(data=net, name='fc3', num_hidden=num_classes)
    return mx.symbol.SoftmaxOutput(data=net, name='softmax')

fit.py里面寫了運行的代碼

import mxnet as mx
import logging
import os
import time


def _get_lr_scheduler(args, kv):
    """Return (base_lr, scheduler), replaying lr decay for resumed training.

    When resuming from --load-epoch, the base lr is pre-multiplied by
    lr_factor once per step epoch already passed, and the scheduler only
    keeps the remaining steps (in units of batches).
    """
    if 'lr_factor' not in args or args.lr_factor >= 1:
        return (args.lr, None)
    # integer division: scheduler steps must be whole batch counts
    # (under Python 3 '/' would yield floats)
    epoch_size = args.num_examples // args.batch_size
    if 'dist' in args.kv_store:
        epoch_size //= kv.num_workers
    begin_epoch = args.load_epoch if args.load_epoch else 0
    step_epochs = [int(l) for l in args.lr_step_epochs.split(',')]
    lr = args.lr
    for s in step_epochs:
        if begin_epoch >= s:
            lr *= args.lr_factor
    if lr != args.lr:
        logging.info('Adjust learning rate to %e for epoch %d' % (lr, begin_epoch))
    steps = [epoch_size * (x - begin_epoch) for x in step_epochs if x - begin_epoch > 0]
    return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor))


def _load_model(args, rank=0):
    """Load the checkpoint named by --model-prefix/--load-epoch.

    Returns (None, None, None) when no epoch was requested. Workers with
    rank > 0 load their per-rank checkpoint when one exists.
    """
    if 'load_epoch' not in args or args.load_epoch is None:
        return (None, None, None)
    assert args.model_prefix is not None
    model_prefix = args.model_prefix
    if rank > 0 and os.path.exists("%s-%d-symbol.json" % (model_prefix, rank)):
        model_prefix += "-%d" % (rank)
    sym, arg_params, aux_params = mx.model.load_checkpoint(model_prefix, args.load_epoch)
    logging.info('Loaded model %s_%04d.params', model_prefix, args.load_epoch)
    return (sym, arg_params, aux_params)


def _save_model(args, rank=0):
    """Return a per-epoch checkpoint callback, or None if --model-prefix is unset."""
    if args.model_prefix is None:
        return None
    dst_dir = os.path.dirname(args.model_prefix)
    # BUG FIX: a bare prefix like 'my' has dirname '' and os.mkdir('') raises;
    # makedirs also creates intermediate directories when needed
    if dst_dir and not os.path.isdir(dst_dir):
        os.makedirs(dst_dir)
    return mx.callback.do_checkpoint(
        args.model_prefix if rank == 0 else "%s-%d" % (args.model_prefix, rank))


def add_fit_args(parser):
    """
    parser : argparse.ArgumentParser
    return a parser added with args required by fit
    """
    train = parser.add_argument_group('Training', 'model training')
    train.add_argument('--network', type=str,
                       help='the neural network to use')
    train.add_argument('--num-layers', type=int,
                       help='number of layers in the neural network, '
                            'required by some networks such as resnet')
    train.add_argument('--gpus', type=str,
                       help='list of gpus to run, e.g. 0 or 0,2,5. '
                            'empty means using cpu')
    train.add_argument('--kv-store', type=str, default='device',
                       help='key-value store type')
    train.add_argument('--num-epochs', type=int, default=100,
                       help='max num of epochs')
    train.add_argument('--lr', type=float, default=0.1,
                       help='initial learning rate')
    train.add_argument('--lr-factor', type=float, default=0.1,
                       help='the ratio to reduce lr on each step')
    train.add_argument('--lr-step-epochs', type=str,
                       help='the epochs to reduce the lr, e.g. 30,60')
    train.add_argument('--optimizer', type=str, default='sgd',
                       help='the optimizer type')
    train.add_argument('--mom', type=float, default=0.9,
                       help='momentum for sgd')
    train.add_argument('--wd', type=float, default=0.0001,
                       help='weight decay for sgd')
    train.add_argument('--batch-size', type=int, default=128,
                       help='the batch size')
    train.add_argument('--disp-batches', type=int, default=20,
                       help='show progress for every n batches')
    train.add_argument('--model-prefix', type=str,
                       help='model prefix')
    # CONSISTENCY FIX: was added via parser instead of the Training group
    train.add_argument('--monitor', dest='monitor', type=int, default=0,
                       help='log network parameters every N iters if larger than 0')
    train.add_argument('--load-epoch', type=int,
                       help='load the model on an epoch using the model-load-prefix')
    train.add_argument('--top-k', type=int, default=0,
                       help='report the top-k accuracy. 0 means no report.')
    train.add_argument('--test-io', type=int, default=0,
                       help='1 means test reading speed without training')
    return train


def fit(args, network, data_loader, **kwargs):
    """
    train a model
    args : argparse returns
    network : the symbol definition of the neural network
    data_loader : function that returns the train and val data iterators
    """
    # kvstore
    kv = mx.kvstore.create(args.kv_store)

    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    logging.info('start with arguments %s', args)

    # data iterators
    (train, val) = data_loader(args, kv)
    if args.test_io:
        # benchmark pure data-reading throughput, then bail out
        tic = time.time()
        for i, batch in enumerate(train):
            for j in batch.data:
                j.wait_to_read()
            if (i + 1) % args.disp_batches == 0:
                logging.info('Batch [%d]\tSpeed: %.2f samples/sec' % (
                    i, args.disp_batches * args.batch_size / (time.time() - tic)))
                tic = time.time()
        return

    # load model: explicit params passed by the caller win over a checkpoint
    if 'arg_params' in kwargs and 'aux_params' in kwargs:
        arg_params = kwargs['arg_params']
        aux_params = kwargs['aux_params']
    else:
        sym, arg_params, aux_params = _load_model(args, kv.rank)
        if sym is not None:
            assert sym.tojson() == network.tojson()

    # save model
    checkpoint = _save_model(args, kv.rank)

    # devices for training
    # BUG FIX: original compared `args.gpus is ''` (identity), which is not
    # guaranteed to be True for an equal string; use equality instead
    devs = mx.cpu() if args.gpus is None or args.gpus == '' else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    # learning rate
    lr, lr_scheduler = _get_lr_scheduler(args, kv)

    # create model
    model = mx.mod.Module(context=devs, symbol=network)

    optimizer_params = {
        'learning_rate': lr,
        'momentum': args.mom,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler}

    monitor = mx.mon.Monitor(args.monitor, pattern=".*") if args.monitor > 0 else None

    initializer = mx.init.Xavier(rnd_type='gaussian', factor_type="in", magnitude=2)

    # evaluation metrics
    eval_metrics = ['accuracy']
    if args.top_k > 0:
        eval_metrics.append(mx.metric.create('top_k_accuracy', top_k=args.top_k))

    # callbacks that run after each batch
    batch_end_callbacks = [mx.callback.Speedometer(args.batch_size, args.disp_batches)]
    if 'batch_end_callback' in kwargs:
        cbs = kwargs['batch_end_callback']
        batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs]

    # run
    model.fit(train,
              begin_epoch=args.load_epoch if args.load_epoch else 0,
              num_epoch=args.num_epochs,
              eval_data=val,
              eval_metric=eval_metrics,
              kvstore=kv,
              optimizer=args.optimizer,
              optimizer_params=optimizer_params,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              batch_end_callback=batch_end_callbacks,
              epoch_end_callback=checkpoint,
              allow_missing=True,
              monitor=monitor)





轉載于:https://www.cnblogs.com/hellokittyblog/p/9128451.html

本文來自互聯網用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權,不承擔相關法律責任。
如若轉載,請注明出處:http://www.pswp.cn/news/372360.shtml
繁體地址,請注明出處:http://hk.pswp.cn/news/372360.shtml
英文地址,請注明出處:http://en.pswp.cn/news/372360.shtml

如若內容造成侵權/違法違規/事實不符,請聯系多彩編程網進行投訴反饋email:809451989@qq.com,一經查實,立即刪除!

相關文章

Apache Shiro第3部分–密碼學

除了保護網頁和管理訪問權限外, Apache Shiro還執行基本的加密任務。 該框架能夠: 加密和解密數據, 哈希數據, 生成隨機數。 Shiro沒有實現任何加密算法。 所有計算都委托給Java密碼學擴展(JCE)API。 使…

mysql數據存在就更新_Mysql:如果數據存在則更新,不存在則插入

mysql語法支持如果數據存在則更新,不存在則插入,首先判斷數據存在還是不存在的那個字段要設置成unique索引,例如表tb_addrbook如下:索引:語句1:不存在插入INSERT INTO tb_addrbook(num,name,mobile) VALUE(1001,小李,1…

Memcached, Redis, MongoDB區別

mongodb和memcached不是一個范疇內的東西。mongodb是文檔型的非關系型數據庫,其優勢在于查詢功能比較強大,能存儲海量數據。mongodb和memcached不存在誰替換誰的問題。和memcached更為接近的是redis。它們都是內存型數據庫,數據保存在內存中&…

洛谷P1757 通天之分組背包 [2017年4月計劃 動態規劃06]

P1757 通天之分組背包 題目背景 直達通天路小A歷險記第二篇 題目描述 自01背包問世之后,小A對此深感興趣。一天,小A去遠游,卻發現他的背包不同于01背包,他的物品大致可分為k組,每組中的物品相互沖突,現在&a…

c3p0 0.9.1.2 配套mysql_連接數據庫,使用c3p0技術連接MySQL數據庫

讀取配置文件連接MySQL數據庫先確認已經導入了 mysql 的驅動包db.propertiesdrivercom.mysql.jdbc.Driverurljdbc:mysql://localhost:3306/v20?useUnicodetrue&characterEncodingutf8usernamerootpassword123456JdbcUtil.javapackage com.stu_mvc.utils;import java.io.Fi…

【Hadoop】Hadoop MR 自定義分組 Partition機制

1、概念 2、Hadoop默認分組機制--所有的Key分到一個組,一個Reduce任務處理 3、代碼示例 FlowBean package com.ares.hadoop.mr.flowgroup;import java.io.DataInput; import java.io.DataOutput; import java.io.IOException;import org.apache.hadoop.io.WritableC…

Spring Framework 3.2 M1發布

SpringSource剛剛宣布了針對Spring 3.2的第一個里程碑版本。 現在可以從SpringSource存儲庫(位于http://repo.springsource.org/)獲得新版本。 查看有關通過Maven 解決這些工件的快速教程 。 此版本包括: 最初支持異步Controller方法 早期…

兩種動態SQL

參考:http://www.cnblogs.com/wanyuan8/archive/2011/11/09/2243483.htmlhttp://www.cnblogs.com/xbf321/archive/2008/11/02/1325067.html 兩種動態SQL  1. EXEC (sql)   2. EXEC sp_executesql 性能:sp_executesql提供了輸入輸出接口,更…

mysql查詢含有某個值的表_MYSQL查詢數據表中某個字段包含某個數值

當某個字段中字符串是"1,2,3,4,5,6"或者"123456" 查詢數據表中某個字段是否包含某個值 1:模糊查詢 使用like select * from 表 where 字段 like %1%; 2:函數查找 find_in_set(str,數組) select * from 表 where find_in_set(1,字段); 注意:mysql字符串…

android學習筆記35——AnimationDrawable資源

AnimationDrawable資源 AnimationDrawable,代表一個動畫。 android既支持傳統的逐幀動畫(類似于電影方式,一張圖片一張圖片的切換),也支持通過平移、變換計算出來的補間動畫、屬性動畫。 下面以補間動畫為例,介紹如何定義Animatio…

RESTEasy教程第2部分:Spring集成

RESTEasy提供了對Spring集成的支持,這使我們能夠將Spring bean作為RESTful WebServices公開。 步驟#1:使用Maven配置RESTEasy Spring依賴項。 <project xmlnshttp:maven.apache.orgPOM4.0.0 xmlns:xsihttp:www.w3.org2001XMLSchema-insta…

java RSA 加簽驗簽【轉】

引用自: http://blog.csdn.net/wangqiuyun/article/details/42143957/ java RSA 加簽驗簽 package com.testdemo.core.service.impl.alipay;import java.security.KeyFactory; import java.security.PrivateKey; import java.security.PublicKey; import java.security.spec.PK…

mysql啟動時執行sql server_常見 mysql 啟動、運行.sql 文件錯誤處理

1、mysql 啟動錯誤處理查看 log&#xff1a;Mac: /usr/local/var/mysql/lizhendeMacBook-Pro.local.err根據 log 針對性的進行調整&#xff0c;包治百病2、Mysql Incorrect datetime value問題描述&#xff1a;低版本的 mysql 中&#xff0c;數據庫轉儲 sql 文件。導入到高版本…

帶有謂詞的Java中的函數樣式-第2部分

在本文的第一部分中,我們介紹了謂詞,這些謂詞通過具有返回true或false的單個方法的簡單接口,為Java等面向對象的語言帶來了函數式編程的某些好處。 在第二部分和最后一部分中,我們將介紹一些更高級的概念,以使您的謂詞…

Devxtreme 顯示Master-Detail數據列表, 數據顯示顏色

1 ////刷新3/4簇Grid2 //function GetClusterGrid(id, coverageId, clusterId) {3 4 // var region getRegionCityName();5 // $.ajax({6 // type: "POST",7 // url: "fast_index_overview.aspx/GetClusterGrid&q…

mysql 排序去重復_php mysql 過濾重復記錄并排序

table1idname1a2b3ctable2idnamecont1aaa2bbb3aaaaaSELECT*,count(distincttable2.name)FROMtable1,table2WHEREtable1.nametable2.nameGROUPBYtable2.nameORDERBYtable2.idDESC";重復...table1id name1 a2 b3 ctable2id name cont1 a aa2 b bb3 a aaaaSELECT *,count(dis…

Java EE 6測試第I部分– EJB 3.1可嵌入API

我們從Enterprise JavaBeans開發人員那里聽到的最常見的請求之一就是需要改進的單元/集成測試支持。 EJB 3.1 Specification引入了EJB 3.1 Embeddable API,用於在Java SE環境中執行EJB組件。 與傳統的基於Java EE服務器的執行不同,可嵌入式用法允許客戶端…

Flume 中文入門手冊

原文&#xff1a;https://cwiki.apache.org/confluence/display/FLUME/GettingStarted 什么是 Flume NG? Flume NG 旨在比起 Flume OG 變得明顯更簡單。更小。更easy部署。在這樣的情況下&#xff0c;我們不提交Flume NG 到 Flume OG 的后向兼容。當前。我們期待來自感興趣測試…

原生JavaScript+CSS3實現移動端滑塊效果

在做web頁面時&#xff0c;無論PC端還是移動端&#xff0c;我們會遇到滑塊這樣的效果&#xff0c;可能我們往往會想著去網上找插件&#xff0c;其實這個效果非常的簡單&#xff0c;插件代碼的的代碼往往過于臃腫&#xff0c;不如自己動手&#xff0c;自給自足。首先看一下效果圖…

mysql的連接名是哪個文件_mysql連接名是什么

{"moduleinfo":{"card_count":[{"count_phone":1,"count":1}],"search_count":[{"count_phone":4,"count":4}]},"card":[{"des":"阿里云數據庫專家保駕護航&#xff0c;為用戶…